In [None]:
# Optional (Google Colab only): mount Google Drive if the dataset archive lives there.
# from google.colab import drive
# drive.mount('/content/drive')
# List the working directory, e.g. to check that training_data_filtered.zip is present.
!ls

In [None]:
!unzip ./training_data_filtered.zip

In [None]:
import torch
import os
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
from tqdm import tqdm,trange

# Set the device for PyTorch: Apple MPS if usable, else CUDA, else CPU.
# torch.device() names NVIDIA GPUs "cuda" ("gpu" is rejected), and the
# `torch.has_mps` flag is deprecated in favour of torch.backends.mps.
has_gpu = torch.cuda.is_available()
_mps_backend = getattr(torch.backends, "mps", None)
has_mps = _mps_backend is not None and _mps_backend.is_available()
device = "mps" if has_mps else "cuda" if has_gpu else "cpu"

# Dataset locations, all relative to the extracted archive root.
data_root = "training_data_filtered/training_data"
training_data_dir = f"{data_root}/v2.2"      # per-sample files
split_dir = f"{data_root}/splits/v2"         # train.txt / val.txt lists
objects_csv = f"{data_root}/objects_v1.csv"  # object table

def get_split_files(split_name, split_root=None, data_root=None):
    """Return the per-sample file paths listed in a split file.

    Reads ``<split_root>/<split_name>.txt``; every non-blank line (after
    stripping surrounding whitespace) is a sample prefix relative to
    ``data_root``.

    Args:
        split_name: split file stem, e.g. 'train' or 'val'.
        split_root: directory holding the split files; defaults to the
            module-level ``split_dir``.
        data_root: directory the prefixes are relative to; defaults to the
            module-level ``training_data_dir``.

    Returns:
        (rgb, depth, label, meta): four parallel lists of paths built from the
        same prefixes with the suffixes ``_color_kinect.png``,
        ``_depth_kinect.png``, ``_label_kinect.png`` and ``_meta.pkl``.
    """
    if split_root is None:
        split_root = split_dir
    if data_root is None:
        data_root = training_data_dir
    split_path = os.path.join(split_root, f"{split_name}.txt")
    # Explicit encoding so the result does not depend on the platform locale.
    with open(split_path, 'r', encoding="utf-8") as f:
        prefix = [os.path.join(data_root, line.strip()) for line in f if line.strip()]
    rgb = [p + "_color_kinect.png" for p in prefix]
    depth = [p + "_depth_kinect.png" for p in prefix]
    label = [p + "_label_kinect.png" for p in prefix]
    meta = [p + "_meta.pkl" for p in prefix]
    return rgb, depth, label, meta

# Path lists for the train and validation splits. Within each split the four
# lists are parallel: index i of each refers to the same sample prefix.
rgb_files, depth_files, label_files, meta_files = get_split_files('train')
rgb_files_val, depth_files_val, label_files_val, meta_files_val = get_split_files('val')

# define the dataset class

def read_image(img_path):
    '''
    Read an image file into a torch tensor.

    inputs:
    img_path : the location of the image to be read
    outputs:
    the pixel array produced by np.array(<PIL image>), wrapped as a
    torch.Tensor. Dtype and shape follow the file (typically uint8,
    (H, W, C) for colour images and (H, W) for single-channel ones).
    '''
    # The context manager closes the file handle deterministically, including
    # when decoding fails, instead of relying on lazy close / garbage collection.
    with Image.open(img_path) as img:
        pixels = np.array(img)
    return torch.from_numpy(pixels)

class mydataset(Dataset):
    """Map-style dataset yielding (image, label) tensor pairs.

    Sample ``idx`` is read from ``img_files[idx]`` and
    ``annotations_files[idx]``, which are used as full paths (``img_dir`` is
    stored but not used to build them). The image is divided by 255.0; the
    label is returned as read. The optional ``transform`` / ``target_transform``
    callables are applied to the image / label respectively.
    """

    def __init__(self, annotations_files, img_files, img_dir, object_files, transform=None, target_transform = None) -> None:
        super().__init__()
        self.target_labels = annotations_files
        self.img_dir = img_dir
        self.img_files = img_files
        # Object table from the CSV; loaded here but not used inside this class.
        self.objects = pd.read_csv(object_files)
        self.transform, self.target_transform = transform, target_transform

    def __len__(self):
        # The dataset size is the number of label files.
        return len(self.target_labels)

    def __getitem__(self, idx):
        image_path, label_path = self.img_files[idx], self.target_labels[idx]
        # Scale pixel values by 1/255 (maps 8-bit images to [0, 1] — assumes uint8 input, TODO confirm).
        image = read_image(image_path) / 255.0
        label = read_image(label_path)
        image = self.transform(image) if self.transform else image
        label = self.target_transform(label) if self.target_transform else label
        return image, label

# Datasets pairing RGB images with label images. Note the constructor's argument
# order: (label paths, image paths, image dir, objects csv). The depth and meta
# path lists are not passed here.
training_data = mydataset(label_files,rgb_files,training_data_dir,objects_csv)
validation_data = mydataset(label_files_val,rgb_files_val,training_data_dir,objects_csv)

# One sample per batch. The validation loader is also shuffled, so
# next(iter(val_dataloader)) yields a random validation sample each time.
train_dataloader = DataLoader(training_data, batch_size=1, shuffle=True)
val_dataloader = DataLoader(validation_data, batch_size=1,shuffle=True)

# train_features, train_labels = next(iter(train_dataloader))
# val_features, val_labels = next(iter(val_dataloader))

# print(train_features.size())
# print(train_labels.size())

# print(f"max :{train_labels.max()}")

# print(val_features.size())
# print(val_labels.size())

# visualize the dataset elements and verify if they are loading correctly

# img = train_features[0].squeeze()
# label = train_labels[0]
# plt.imshow(label)
# plt.show()

# create a dataloader now or in the train routine

# define the network

class Segmentation(nn.Module):
    """Small U-Net-style encoder/decoder for per-pixel classification.

    Encoder: four conv stages (64, 128, 256, 512 channels) separated by three
    2x2 max-pools. Decoder: three 2x transposed-conv upsamplings, each
    concatenated with the matching encoder feature map (skip connection) and
    fused by a conv stage; a final 1x1 conv emits one logit map per class.

    Args:
        in_channels: channels of the input image (default 3).
        num_classes: number of output logit channels (default 82).

    ``forward(x)`` takes (N, in_channels, H, W) and returns
    (N, num_classes, H, W) logits; ``.squeeze(1)`` only drops the class
    dimension when num_classes == 1. The 3x3 convs use reflect padding, which
    needs every feature map to be at least 2x2, i.e. H and W of at least 16.
    H and W need not be divisible by 8 (see ``_pad_to``).
    """

    def __init__(self, in_channels=3, num_classes=82):
        super().__init__()
        # Modules are created in the same order (and with the same attribute
        # names / Sequential indices) as before, so state_dict keys and seeded
        # initial weights are unchanged.
        self.c1 = self._conv_bn_relu(in_channels, 64)
        self.c2 = self._conv_bn_relu(64, 128)
        self.c3 = self._conv_bn_relu(128, 256)
        self.c4 = self._conv_bn_relu(256, 512)
        self.p1 = nn.MaxPool2d(2)
        self.p2 = nn.MaxPool2d(2)
        self.p3 = nn.MaxPool2d(2)
        # 2x upsampling; each halves the channel count.
        self.d1 = nn.ConvTranspose2d(128, 64, 2, 2)
        self.d2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.d3 = nn.ConvTranspose2d(512, 256, 2, 2)
        # Decoder fusion stages; dc1 also holds the 1x1 classifier head.
        self.dc1 = nn.Sequential(
            *self._conv_bn_relu(128, 64),
            nn.Conv2d(64, num_classes, 1),
        )
        self.dc2 = self._conv_bn_relu(256, 128)
        self.dc3 = self._conv_bn_relu(512, 256)

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch):
        """3x3 reflect-padded conv -> BatchNorm -> ReLU (spatial size preserved)."""
        return nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1, padding_mode="reflect"),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
        )

    @staticmethod
    def _pad_to(up, skip):
        """Zero-pad ``up`` on the bottom/right so its H, W equal ``skip``'s.

        Max-pooling floors odd sizes, so the 2x transposed conv can come back
        one pixel short of the skip tensor and torch.cat would fail. When the
        sizes already match (e.g. H, W divisible by 8) ``up`` is returned as is.
        """
        dh = skip.shape[-2] - up.shape[-2]
        dw = skip.shape[-1] - up.shape[-1]
        if dh == 0 and dw == 0:
            return up
        # F.pad order for the last two dims: (left, right, top, bottom).
        return nn.functional.pad(up, (0, dw, 0, dh))

    def forward(self, x):
        x1 = self.c1(x)
        x2 = self.c2(self.p1(x1))
        x3 = self.c3(self.p2(x2))
        x4 = self.c4(self.p3(x3))
        y3 = torch.cat([x3, self._pad_to(self.d3(x4), x3)], dim=1)
        y2 = torch.cat([x2, self._pad_to(self.d2(self.dc3(y3)), x2)], dim=1)
        y1 = torch.cat([x1, self._pad_to(self.d1(self.dc2(y2)), x1)], dim=1)
        return self.dc1(y1).squeeze(1)

In [None]:
# Build the network and train it on train_dataloader.

# Choose the device locally so this cell also runs on machines without CUDA
# (it used to hard-code "cuda").
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model = Segmentation()
model.to(device)

# define the optimizer and criterion
criterion = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(), lr=0.001)

# define the hyper-parameters
batch_size = 1    # informational only: the DataLoaders are built with batch_size=1 elsewhere
epochs = 5
print_freq = 10   # log the mean loss every `print_freq` iterations
epoch_save = 100  # currently unused
batch_loss = []   # one mean-loss entry per `print_freq` iterations, across all epochs

# A later cell calls model.eval(); put BatchNorm back in training mode in case
# this cell is re-run afterwards.
model.train()
for epoch in trange(epochs):
    print_loss = 0.0
    for epoch_step, (images, labels) in enumerate(train_dataloader, start=1):
        # Batches arrive channels-last (hence the permute); the model expects (N, C, H, W).
        inputs = torch.permute(images, (0, 3, 1, 2)).to(device)
        targets = labels.long().to(device)
        optim.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optim.step()
        print_loss += loss.item()
        if epoch_step % print_freq == 0:
            mean_loss = print_loss / print_freq
            print(f"[{epoch+1}/{epochs}][{epoch_step}/{len(train_dataloader)}]")
            print(f"loss: {mean_loss}")
            batch_loss.append(mean_loss)
            print_loss = 0.0

In [None]:
# Training curve: point k is the mean loss of the k-th block of `print_freq`
# iterations. The x positions equal the cumulative iteration count when each
# epoch's length is a multiple of print_freq (the step counter resets per epoch).
# Derived from len(batch_loss) instead of a hard-coded range, so it matches any run length.
iterations = print_freq * np.arange(1, len(batch_loss) + 1)
fig, ax = plt.subplots()
ax.plot(iterations, batch_loss)
ax.set_xlabel('iterations')
ax.set_ylabel('loss')
ax.set_title('Training loss');

In [None]:
# Persist the whole trained model object; the file name carries the final epoch number.
save_path = f"./model_new_plot_{epoch+1}.pth"
torch.save(model, save_path)

In [None]:
# Reload a saved model and run it on one (random) validation sample.
if torch.cuda.is_available():
    device = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

# NOTE(review): the save cell writes "./model_new_plot_{epoch+1}.pth", but this
# loads "./model_new_5.pth". Confirm which checkpoint is intended.
checkpoint_path = "./model_new_5.pth"
try:
    # The file holds a whole pickled nn.Module (saved via torch.save(model, ...)),
    # which torch>=2.6 rejects under its default weights_only=True.
    # Unpickling can execute code, so only load checkpoints you created yourself.
    # map_location lets a CUDA-saved checkpoint load on a CPU-only machine.
    model = torch.load(checkpoint_path, map_location=device, weights_only=False)
except TypeError:
    # torch < 1.13 has no `weights_only` argument.
    model = torch.load(checkpoint_path, map_location=device)
model.to(device)
model.eval()
test = next(iter(val_dataloader))
print(test[0].size())
with torch.no_grad():
    out_test = model(torch.permute(test[0], (0, 3, 1, 2)).to(device))
print(out_test.size())

In [None]:
# Per-pixel predicted class id: index of the largest logit along dim 1.
test_labels = out_test.argmax(dim=1)
print(test_labels.size())

In [None]:
# Colour palette for label maps: NUM_OBJECTS + 3 RGB rows, sampled from the
# 'rainbow' colormap, with the last three rows overridden by fixed colours.
# matplotlib.cm.get_cmap was removed in Matplotlib 3.9; pyplot.get_cmap still
# accepts the `lut` (number of colours) argument.
NUM_OBJECTS = 79
cmap = plt.get_cmap('rainbow', NUM_OBJECTS)
COLOR_PALETTE = np.array([cmap(i)[:3] for i in range(NUM_OBJECTS + 3)])
COLOR_PALETTE = np.array(COLOR_PALETTE * 255, dtype=np.uint8)
COLOR_PALETTE[-3] = [119, 135, 150]
COLOR_PALETTE[-2] = [176, 194, 216]
COLOR_PALETTE[-1] = [255, 255, 225]
# Index with an explicit numpy array rather than a torch tensor.
plt.imshow(COLOR_PALETTE[test_labels.squeeze(0).detach().cpu().numpy()])

In [None]:
# Ground-truth label map of the same validation sample, coloured with the same palette.
gt_labels = test[1].squeeze(0).detach().cpu()
plt.imshow(COLOR_PALETTE[gt_labels])