In [None]:
import os
import torch
import cv2

import matplotlib.pyplot as plt
import numpy as np

from bs4 import BeautifulSoup
from PIL import Image

In [None]:
# If using Intel Arc GPU (like I am), run this cell.
import intel_extension_for_pytorch as ipex

In [None]:
# Dataset root directory, given relative to the notebook's working directory.
DATA_ROOT_PATH = os.path.join("../", "guide3d/data/guide3d")
# Annotation file that sits directly under the dataset root.
ANNOTATION_FILE_PATH = os.path.join(DATA_ROOT_PATH, "annotations.xml")

In [None]:
# Read the whole annotation file into memory as text.
with open(ANNOTATION_FILE_PATH, 'r') as f:
    xml_data = f.read()

# NOTE(review): "lxml" selects BeautifulSoup's lxml-based HTML parser; confirm
# that is intended for an .xml file (the XML parser is requested with "xml").
xml_parsed = BeautifulSoup(xml_data, "lxml")

In [None]:
# Load one hand-picked frame to experiment with in the cells below.
some_img = Image.open(
    os.path.join(DATA_ROOT_PATH, "1-bca-straight-1-2", "241.png")
)

# Last expression of the cell: displays the size reported by PIL.
some_img.size

In [None]:
# Adaptive Gaussian thresholding of the sample frame: maximum value 255,
# binary output, block size 11 and constant 2.
edges = cv2.adaptiveThreshold(
    np.array(some_img),
    255,
    cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
    cv2.THRESH_BINARY,
    11,
    2,
)

plt.imshow(edges, vmin = 0, vmax = 255, cmap = "gray")


In [None]:
# Zoom in on the sample frame: rows 592±40 and columns 192±40.
plt.imshow(np.array(some_img)[592-40:592+40, 192-40:192+40], cmap = "gray")

In [None]:
# Zoom in on the sample frame: rows 301±80 and columns 508±80.
plt.imshow(np.array(some_img)[301-80:301+80, 508-80:508+80], cmap = "gray")

In [None]:
# Collect every <camera> and <reconstruction> node from the parsed annotations.
camera_xml_nodes = xml_parsed.find_all("camera")
reconstruction_xml_nodes = xml_parsed.find_all("reconstruction")

# Build the array of image names in one go: calling np.append inside a loop
# copies the whole array on every iteration.
images = np.array([camera_node["image"] for camera_node in camera_xml_nodes])

if camera_xml_nodes:
    # The following cell reads `each_camera_node`; the original loop left it
    # bound to the last <camera> node, so keep that binding explicit.
    each_camera_node = camera_xml_nodes[-1]
    # Blank 1024x1024 canvas (previously re-allocated on every loop iteration).
    mask = np.zeros((1024, 1024))

print(images[2])


In [None]:
# Rasterise the annotated points of `each_camera_node` (bound by the previous
# cell) into a binary mask. The "points" attribute is "x1,y1;x2,y2;...".
# The string is split once up front instead of once per loop iteration.
point_pairs = [
    pair.split(',')
    for pair in each_camera_node["points"].split(';')
    if pair.strip()  # tolerate a trailing ';'
]

# Kept as float64 arrays, which is what np.append onto an empty array produced.
x = np.array([int(j) for j, _ in point_pairs], dtype = np.float64)
y = np.array([int(k) for _, k in point_pairs], dtype = np.float64)

mask = np.zeros((1024, 1024))
# Rows are indexed by y and columns by x.
mask[y.astype(int), x.astype(int)] = 1
    

In [None]:
# Display the point mask without smoothing between pixels.
plt.imshow(mask, cmap = "gray", aspect = "auto", interpolation = "nearest")

In [None]:
# Show the untouched frame first.
plt.imshow(np.array(some_img), aspect = "auto", interpolation = "nearest")

# Turn the annotated points into a polyline; the vertices are passed to
# OpenCV as an int32 array shaped (N, 1, 2).
x = np.int32(x)
y = np.int32(y)

mask_points = np.stack((x, y), axis = 1).reshape((-1, 1, 2))

# Draw an open, 2-pixel-thick polyline on a fresh array copy of the frame.
# NOTE(review): if the frame is single-channel, OpenCV presumably uses only the
# first colour component (0 here) - confirm the line shows up as intended.
closed_img = cv2.polylines(
    np.array(some_img),
    [mask_points],
    isClosed = False,
    color = (0, 0, 255),
    thickness = 2,
)

plt.imshow(closed_img, aspect = "auto")

### Image Training for Segmentation

In [None]:
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset

from torch import nn
from torch.optim import SGD

from tqdm import tqdm

import warnings
# NOTE: this silences every warning for the rest of the session, including
# deprecation warnings raised by the libraries used in the cells below.
warnings.filterwarnings("ignore")

In [None]:
# Difference between the frame with the polyline drawn on it and the original
# frame; it is non-zero only where drawing changed a pixel value.
# NOTE(review): this rebinds `mask_points`, which held the polyline vertices in
# an earlier cell. If the pixels are unsigned integers the subtraction wraps
# around instead of going negative - confirm that is fine for this preview.
mask_points = np.subtract(closed_img, np.array(some_img))
plt.imshow(mask_points, cmap = "gray")

In [None]:
def GetImagesList():
    """Return the "image" attribute of every <camera> node in the annotations.

    Reads the module-level ``xml_parsed`` document and shows a tqdm progress
    bar while iterating over the nodes.

    Returns:
        np.ndarray: 1-D array of image names in document order (an empty
        array when the document has no <camera> nodes).
    """
    camera_xml_nodes = xml_parsed.find_all("camera")

    # Collect into a list and convert once: np.append in a loop copies the
    # whole array on every iteration, which is quadratic in the node count.
    return np.array([camera_node["image"] for camera_node in tqdm(camera_xml_nodes)])

def GetPoints(image_name):
    """Return the raw "points" attribute of the <camera> node for ``image_name``.

    Looks the node up in the module-level ``xml_parsed`` document.

    Args:
        image_name: Value of the node's "image" attribute to look for.

    Returns:
        The "points" string of the first matching <camera> node in document
        order, or ``None`` when no node matches.
    """
    # find() stops at the first match instead of materialising every <camera>
    # node on each call. str() normalises numpy string scalars to plain str.
    camera_node = xml_parsed.find("camera", attrs = {"image": str(image_name)})
    if camera_node is None:
        return None

    return camera_node["points"]

In [None]:
# Image names of all annotated frames (one entry per <camera> node).
a = GetImagesList()

In [None]:
# Spot-check: open and display the frame stored at index 1000 of the list.
test_img = Image.open(os.path.join(DATA_ROOT_PATH, a[1000]))

plt.imshow(np.array(test_img))



In [None]:
class SegmDataLoader(Dataset):
    """Map-style dataset pairing each annotated frame with a polyline mask.

    Despite its name this is a ``torch.utils.data.Dataset``, not a
    ``DataLoader``. Image names come from ``GetImagesList()`` and the annotated
    points of each image from ``GetPoints()``.

    Args:
        transform: Optional callable applied to the image tensor.
        target_transform: Optional callable applied to the mask tensor.
    """

    def __init__(self, transform = None, target_transform = None):
        self.images = GetImagesList()
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_name, image, mask)`` for the sample at ``idx``.

        Returns:
            tuple: the image name as stored in the annotations, the image as
            a ``float32`` tensor, and a ``long`` tensor with the image's
            height and width that is 1 on the 2-pixel-thick annotated
            polyline and 0 elsewhere.

        Raises:
            IndexError: if ``idx`` is out of range; plain ``for`` iteration
                over the dataset relies on this to stop.
        """
        image_name = self.images[idx]
        points = GetPoints(image_name)

        # "x1,y1;x2,y2;..." -> int32 vertices shaped (N, 1, 2) for cv2.polylines.
        # The string is split once rather than once per point.
        vertices = []
        for pair in points.split(';'):
            px, py = pair.split(',')
            vertices.append((int(px), int(py)))
        vertices = np.array(vertices, dtype = np.int32).reshape((-1, 1, 2))

        # Read the pixels inside a context manager so the file handle is
        # released as soon as the array has been built.
        with Image.open(os.path.join(DATA_ROOT_PATH, image_name)) as frame:
            frame_pixels = np.array(frame)

        # Draw the polyline on an empty canvas. Diffing an annotated copy of
        # the frame against the original misses every pixel whose value
        # already equals the drawing colour.
        mask = np.zeros(frame_pixels.shape[:2], dtype = np.uint8)
        mask = cv2.polylines(mask, [vertices], isClosed = False, color = 1, thickness = 2)

        image_tensor = torch.tensor(frame_pixels, dtype = torch.float32)
        mask_tensor = torch.tensor(mask, dtype = torch.long)

        # The transforms were accepted by __init__ but never applied before.
        if self.transform is not None:
            image_tensor = self.transform(image_tensor)
        if self.target_transform is not None:
            mask_tensor = self.target_transform(mask_tensor)

        return image_name, image_tensor, mask_tensor

In [None]:
# Smoke test: build the dataset and display the mask of its first sample.
test_dl = SegmDataLoader()

for i in test_dl:
    # Each sample is (image name, image tensor, mask tensor).
    _, tmp_img, segm = i

    plt.imshow(segm, cmap = "gray")

    # Only the first sample is needed for the preview.
    break

In [None]:
class Network(nn.Module):
    """Three-layer fully connected network over a flattened image.

    Args:
        num_pixels: Size of the flattened input and of the output vector.
            Defaults to ``1024 * 1024``.
        hidden_size: Width of the two hidden layers. Defaults to ``256``.
    """

    def __init__(self, num_pixels = 1024*1024, hidden_size = 256):
        super().__init__()

        self.linear_1 = nn.Linear(num_pixels, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, num_pixels)
        self.relu = nn.ReLU()
        # Explicit dim: relying on the implicit choice is deprecated. -1 is the
        # axis the implicit rule picked for 1-D and 2-D inputs.
        self.lg_softmax = nn.LogSoftmax(dim = -1)

    def forward(self, x):
        """Map a flattened image to log-probabilities over its last axis.

        Args:
            x: Tensor whose last dimension has ``num_pixels`` entries.

        Returns:
            Tensor of the same shape as ``x`` holding a log-softmax over the
            last dimension.
        """
        # NOTE(review): the output is one distribution across all pixels, not
        # a per-pixel class score - confirm that is the intended formulation.
        hidden = self.relu(self.linear_1(x))
        hidden = self.relu(self.linear_2(hidden))

        return self.lg_softmax(self.linear_3(hidden))

In [None]:
# Prefer the Intel XPU this notebook was written for, but fall back to CUDA or
# the CPU so the cell also runs where the optional ipex cell was skipped.
if hasattr(torch, "xpu") and torch.xpu.is_available():
    DEVICE = "xpu"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

model = Network().to(DEVICE)
print(model)

In [None]:
from torch.autograd import Variable

test_dl = SegmDataLoader()
TOTAL_EPOCHS = 3

# Run on whatever device the model already lives on instead of repeating a
# device literal here.
device = next(model.parameters()).device
loss_f = nn.NLLLoss().to(device)
optimiser = SGD(model.parameters(), lr=0.00001)

print("Trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))

# One entry per training step (all epochs concatenated): the running mean of
# the loss within the current epoch.
loss_history = []

try:
    with torch.set_grad_enabled(True):
        for epoch in range(TOTAL_EPOCHS):
            epoch_loss_sum = 0.0

            for i, single_data in enumerate(tqdm(test_dl)):
                _, tmp_img, segm = single_data
                tmp_img = tmp_img.reshape(-1).to(device)
                segm = segm.reshape(-1).to(device)

                # Clear gradients before every step. Zeroing only once per
                # epoch lets gradients from all earlier samples pile up.
                optimiser.zero_grad()

                Y_preds = model(tmp_img)

                # NOTE(review): Y_preds is a 1-D vector of log-probabilities
                # while segm holds one label per pixel; NLLLoss normally takes
                # (N, C) scores with N targets - confirm this loss measures
                # what is intended.
                loss = loss_f(Y_preds, segm)

                loss.backward()
                optimiser.step()

                # Keep a running sum so the mean costs O(1) per step.
                epoch_loss_sum += loss.item()
                loss_history.append(epoch_loss_sum / (i + 1))

                if i % 500 == 0:
                    print("EPOCH MIN MAX", epoch, min(loss_history), max(loss_history))
finally:
    # Same name and type as before, and still populated if training is
    # interrupted part-way through.
    losses = np.array(loss_history)

# Save once all epochs have finished. The file names follow TOTAL_EPOCHS rather
# than a hard-coded epoch index, and are unchanged for TOTAL_EPOCHS = 3.
torch.save(model, "trained_model_{}_epochs".format(TOTAL_EPOCHS))
torch.save(model.state_dict(), "trained_model_{}_epochs_statedict".format(TOTAL_EPOCHS))

### Testing

In [None]:
# Restore the trained weights from the state dict written by the training cell.
# Unpickling a whole module object is rejected by default in recent PyTorch
# releases (torch.load defaults to weights_only=True) and can run arbitrary
# code from the file.
state_dict = torch.load("trained_model_3_epochs_statedict")

model = Network()
model.load_state_dict(state_dict)
# Put the model on the device the weights were saved from, which is where
# loading the pickled module would have placed it.
model = model.to(next(iter(state_dict.values())).device)