In [None]:
import math
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from rembg import remove, new_session
from torchvision import datasets

from data.dataset import MaskedDataset

In [None]:
# Pick the compute backend in order of preference: Apple MPS, then CUDA,
# falling back to the CPU when neither accelerator is available.
device = torch.device(
    'mps' if torch.backends.mps.is_available()
    else 'cuda' if torch.cuda.is_available()
    else 'cpu'
)

print('Running on {}'.format(device))

In [None]:
# Fix the global NumPy and PyTorch RNGs so repeated runs behave the same.
seed = 0

np.random.seed(seed)
torch.manual_seed(seed)

# Separate, explicitly seeded generator that is passed to the DataLoader.
generator = torch.Generator().manual_seed(seed)

# torch.use_deterministic_algorithms(True)

In [None]:
# Number of images per DataLoader batch. The 10x10 preview grid further down
# has 100 panels, one per image of a batch.
batch_size = 100

# Load data

In [None]:
# Open the 'unlabeled' split of STL10 stored under ./data/STL10/.
# `download` is not passed here, so torchvision's default for it applies.
data_main = datasets.STL10(root = "./data/STL10/", split = 'unlabeled')

In [None]:
# Wrap the raw STL10 data in the project's MaskedDataset.
# NOTE(review): the cells below unpack four values per item and use the second
# one as the image tensor — confirm against data/dataset.py.
train_dataset = MaskedDataset(data_main)

# Number of samples; also the number of mask files the export loop writes.
print(len(train_dataset))

In [None]:
# Batches are served in dataset order (shuffle = False). The mask-export loop
# numbers its output files with a running counter, and later cells look those
# files up by dataset index, so the order must not change.
train_loader = torch.utils.data.DataLoader(
    dataset = train_dataset,
    batch_size = batch_size,
    shuffle = False,
    generator = generator
)

# Generate segmentation masks with Rembg

In [None]:
# Export one rembg foreground mask per dataset image to disk.
#
# The output directory is created up front: torch.save does not create missing
# parent directories, so on a fresh checkout the first save would fail.
segmentation_dir = './data/STL10_segmentations'
os.makedirs(segmentation_dir, exist_ok=True)

# One rembg session is reused for every image instead of creating one per call.
segment_session = new_session()

# Running index over all images; the loader is unshuffled, so it equals the
# dataset index and is used as the file name suffix.
file_counter = 0
for _, full_res, _, _ in train_loader:
    # Reset for every batch: once the loop ends, only the masks of the final
    # batch remain, matching `full_res`, which is also left over from that batch.
    predicted_segmentation = []

    for image in full_res:
        # Channels-first tensor -> channels-last uint8 array for rembg.
        # assumes pixel values are floats in [0, 1] — TODO confirm against MaskedDataset
        image_uint8 = (image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        segment_mask = remove(image_uint8, session=segment_session, only_mask=True, post_process_mask=False)

        # Rescale the returned mask from the 0..255 range to 0..1 floats.
        segment_mask = torch.tensor(segment_mask / 255, dtype=torch.float32)  # device=device

        # Heuristic to filter bad masks
        # comment the if statement if you want all segmentation masks
        # if len(segment_mask[segment_mask >= 0.8]) >= 100:
        #     torch.save(segment_mask, './data/STL10_segmentations_filtered/segmentation_{}.pt'.format(file_counter))

        torch.save(segment_mask, '{}/segmentation_{}.pt'.format(segmentation_dir, file_counter))
        predicted_segmentation.append(segment_mask)
        file_counter += 1

    # Progress report, checked once per batch.
    if file_counter % 100 == 0:
        print(file_counter)

# Stack the final batch's masks into one tensor for the preview cell.
predicted_segmentation = torch.stack(predicted_segmentation, dim=0)
print(predicted_segmentation.shape)

In [None]:
# Preview the final batch: every image with its predicted mask drawn on top.
fig, axs = plt.subplots(10, 10, layout='constrained', figsize=(20, 20))

# The grid has 100 panels, but the batch left over from the export loop may
# hold fewer images (e.g. a different batch size or a partial last batch).
# Only as many panels as there are masks are filled, which avoids an IndexError.
num_previews = min(len(predicted_segmentation), axs.size)

# axs.flat walks the grid row by row, i.e. in batch order.
for index, ax in enumerate(axs.flat):
    if index >= num_previews:
        ax.axis('off')  # leave surplus panels blank
        continue

    ax.imshow((full_res[index].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
    ax.imshow(predicted_segmentation[index].cpu().numpy(), alpha=0.4)  # cmap='hot'

    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

## Show insufficient segmentation masks

In [None]:
# Re-read every exported mask and collect the ones with little confident
# foreground: at most 100 pixels with a value of 0.8 or higher.
insufficient_indices = []
insufficient_masks = []

for mask_id in range(len(train_dataset)):
    # all segmentations (unfiltered)
    mask_path = './data/STL10_segmentations/segmentation_{}.pt'.format(mask_id)
    mask = torch.load(mask_path, map_location=device)

    confident_pixels = int((mask >= 0.8).sum())
    if confident_pixels > 100:
        continue  # enough foreground, not of interest here

    insufficient_indices.append(mask_id)
    insufficient_masks.append(mask)

    print(mask_id)

In [None]:
# How many masks were collected as insufficient by the scan above.
print(len(insufficient_indices))

In [None]:
# Page through the insufficient masks: one 10x10 figure per 100 entries.
for page_start in range(0, len(insufficient_indices), 100):
    fig, axs = plt.subplots(10, 10, layout='constrained', figsize=(20, 20))

    page_indices = insufficient_indices[page_start:page_start + 100]
    page_masks = insufficient_masks[page_start:page_start + 100]

    # zip stops at the shortest input, so panels beyond the last entry on the
    # final page are simply left untouched.
    for ax, dataset_index, mask in zip(axs.flat, page_indices, page_masks):
        # Second element of the dataset item is the image that gets displayed.
        image = train_dataset[dataset_index][1]
        ax.imshow((image.permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8))
        ax.imshow(mask.cpu().numpy(), alpha=0.5, cmap='viridis')

        # Title each panel with its dataset index for later lookup.
        ax.set_title(dataset_index)

        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)

    plt.show()