# Data exploration

## I- Import relevant libraries

In [1]:
import os
import sys
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches

The cell below adds ``WeakSegNet/src`` to the Python path so that the project modules it contains can be imported.

In [None]:
# Make the project's modules importable from this notebook.
# BASE_DIR is the parent of the current working directory, i.e. the repo
# directory WeakSegNet/ when the notebook is run from a direct sub-folder.
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), ".."))
SRC_PATH = os.path.join(BASE_DIR, "src")
# Guard the append so that re-running this cell does not pile up duplicate
# entries in sys.path.
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

from dataset import OxfordPet, data_loading

## II- Load the data

In [None]:
# Sanity check: is the annotation file of the chosen split available?
# Set `split` to "trainval" or "test" to check either file.

FILE_PATH = os.path.join(BASE_DIR, "data")

split = "trainval" # or "test"
# Build the path once so the file that is reported and the file that is
# checked are guaranteed to be the same one.
split_file = os.path.join(FILE_PATH, "annotations", f"{split}.txt")
print(f"Looking for {split}.txt in:", split_file)
print("Does the file exists?", os.path.exists(split_file))

## III- Data visualisation

### 1- Fully supervised case

In [None]:
# Fully supervised setting: build the datasets and their loaders.
image_size = 256
batch_size_train = batch_size_val = batch_size_test = 16
# Presumably the fraction of data held out for validation — confirm in data_loading.
val_split = 0.2

(
    train_dataset,
    val_dataset,
    test_dataset,
    train_loader,
    val_loader,
    test_loader,
) = data_loading(
    path=FILE_PATH,
    fully_supervised=True,
    weakly_supervised=False,
    image_size=image_size,
    batch_size_train=batch_size_train,
    batch_size_val=batch_size_val,
    batch_size_test=batch_size_test,
    val_split=val_split,
)

In [None]:
# Fetch the training sample at index 35; it unpacks into the image, its mask
# and an `info` record (indexed by string keys such as 'name' further down).
image, mask, info = train_dataset[35]

In [None]:
# Show the sampled image (left) next to its segmentation mask (right).
plt.subplot(1, 2, 1)
# permute(1, 2, 0) moves the first axis last for imshow
# (assumes `image` is a channel-first tensor — TODO confirm).
plt.imshow(image.permute(1, 2, 0))
plt.title("Image")
plt.axis('off')
plt.grid(False)

plt.subplot(1, 2, 2)
# Use the canonical colormap name "gray": the "grey" spelling is only an
# alias in recent Matplotlib releases and fails on older ones.
plt.imshow(mask.squeeze(), cmap="gray")
plt.title(f"Mask for image {info['name']}")
plt.axis('off')
plt.grid(False)

plt.show()

In [None]:
# List the distinct values present in the sampled mask.
unique_vals = mask.unique()
print("Unique mask values:", unique_vals)

### 2- Weakly supervised case

In [None]:
# Weakly supervised setting: rebuild the datasets and loaders with the
# supervision flags swapped relative to the fully supervised cell.
image_size = 256
batch_size_train = batch_size_val = batch_size_test = 16
# Presumably the fraction of data held out for validation — confirm in data_loading.
val_split = 0.2

(
    train_dataset,
    val_dataset,
    test_dataset,
    train_loader,
    val_loader,
    test_loader,
) = data_loading(
    path=FILE_PATH,
    fully_supervised=False,
    weakly_supervised=True,
    image_size=image_size,
    batch_size_train=batch_size_train,
    batch_size_val=batch_size_val,
    batch_size_test=batch_size_test,
    val_split=val_split,
)

In [None]:
# Fetch the training sample at index 78; in this setting a sample unpacks
# into two items only: the image and its `info` record.
image, info = train_dataset[78]

In [None]:
# Bounding box stored under the 'bbox' key of the sample's info record.
box = info['bbox']

In [None]:
# Draw the sampled image with its bounding box overlaid in red.
# The box is read here as (x_start, x_end, y_start, y_end); this mirrors how
# the values are used below — confirm against the dataset's bbox layout.
x_start, x_end = box[0], box[1]
y_start, y_end = box[2], box[3]
width = x_end - x_start
height = y_end - y_start

# permute(1, 2, 0) moves the first axis last for imshow
# (assumes `image` is a channel-first tensor — TODO confirm).
plt.imshow(image.permute(1, 2, 0))
plt.title("Image")
plt.axis('off')
plt.grid(False)

# Unfilled rectangle anchored at (x_start, y_start).
bbox_patch = patches.Rectangle(
    (x_start, y_start),
    width,
    height,
    linewidth=2,
    edgecolor='r',
    facecolor='none',
)
ax = plt.gca()
ax.add_patch(bbox_patch)
plt.show()

In [None]:
# Echo the raw 'bbox' entry so the notebook displays its values.
info['bbox']