![Open Screen](https://github.com/Slenderman00/open_screen/blob/master/media/banner.png?raw=true)

### Download the dataset

Remember to install the Kaggle CLI (`pip install kaggle`) and to set your Kaggle API key before running this cell.

In [None]:
import os
import subprocess
from IPython.display import clear_output
from shutil import rmtree


cwd = os.getcwd()
main_dataset_path = f'{cwd}/dataset'


def _run_download_script():
    """Run ``download_dataset.sh`` and echo its most recent output line.

    Raises:
        subprocess.CalledProcessError: if the script exits with a non-zero
            status (an ``assert`` would be stripped under ``python -O``).
    """
    command = ['bash', 'download_dataset.sh']
    # The context manager closes the pipe and waits for the process to exit.
    with subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    ) as process:
        for stdout_line in process.stdout:
            clear_output(wait=True)
            print(stdout_line, end="")
    if process.returncode != 0:
        raise subprocess.CalledProcessError(process.returncode, command)


def _next_free_index(image_dir):
    """Return one more than the largest numeric file stem in ``image_dir``.

    Files whose stem is not a plain integer are ignored; an empty directory
    (or one without numeric stems) yields 0.
    """
    used = [
        int(stem)
        for stem, _ext in map(os.path.splitext, os.listdir(image_dir))
        if stem.isdecimal()
    ]
    return max(used) + 1 if used else 0


def _merge_extra_dataset(name):
    """Move every image/mask pair of the extra dataset ``name`` into the main dataset.

    The extra dataset is read from ``<cwd>/<name>/<name>/{images,masks}``.
    Each pair is renamed to the next free numeric index of the main dataset,
    and the emptied source tree is removed afterwards.
    """
    source_dir = f'{cwd}/{name}/{name}'
    # Recomputed for every merged dataset so that a later merge can never
    # overwrite files moved by an earlier one (or the last original sample).
    next_index = _next_free_index(f'{main_dataset_path}/image')

    # Sorted so the resulting numbering is reproducible across runs.
    for file in sorted(os.listdir(f'{source_dir}/images')):
        # splitext keeps dots that are part of the file name itself.
        stem = os.path.splitext(file)[0]

        os.rename(f'{source_dir}/masks/{stem}.png', f'{main_dataset_path}/mask/{next_index}.png')
        os.rename(f'{source_dir}/images/{file}', f'{main_dataset_path}/image/{next_index}.jpg')
        next_index += 1

    rmtree(f'{cwd}/{name}')


# Only download and merge when the main dataset folder does not exist yet.
if not os.path.exists(main_dataset_path):
    _run_download_script()
    print('Finished downloading and extracting')

    for extra_dataset in (
        'supervisely_person_clean_2667_img',
        'segmentation_full_body_tik_tok_2615_img',
    ):
        _merge_extra_dataset(extra_dataset)

### Calculate mean and std

In [None]:
import os
from PIL import Image
import torch
from torchvision import transforms

cwd = os.getcwd()
image_dir = f'{cwd}/dataset/image'

# float64 accumulators: the running sums cover every pixel of every image,
# which is far beyond what float32 can add up without losing precision.
sums = torch.zeros(3, dtype=torch.float64)
squared_sums = torch.zeros(3, dtype=torch.float64)
count = 0

to_tensor = transforms.ToTensor()

for file_name in os.listdir(image_dir):
    # The context manager closes the file handle as soon as the pixels are read.
    with Image.open(f'{image_dir}/{file_name}') as raw_image:
        image_tensor = to_tensor(raw_image.convert('RGB')).double()

    # Per-channel sums over the two spatial dimensions.
    sums += image_tensor.sum(dim=[1, 2])
    squared_sums += (image_tensor ** 2).sum(dim=[1, 2])
    count += image_tensor.size(1) * image_tensor.size(2)

    del image_tensor

if count == 0:
    raise RuntimeError(f'No pixels found in {image_dir}; cannot compute mean and std')

mean = sums / count
# Var[x] = E[x^2] - E[x]^2; clamp at zero so rounding can never produce NaN.
std = (squared_sums / count - mean ** 2).clamp(min=0).sqrt()

# Hand back float32 tensors, matching the image tensors they will normalise.
mean = mean.float()
std = std.float()

print(f'Mean: {mean}')
print(f'Std: {std}')

### Create a custom dataset loader for PyTorch

In [None]:
import os
from torchvision import transforms
from torch.utils.data import Dataset
from PIL import Image

# Every image and mask is resized to this fixed (height, width).
_target_size = (512, 512)

# Image pipeline: resize, convert to a tensor, then normalise with the
# per-channel `mean` and `std` computed earlier in the notebook.
_image_steps = [
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
]
image_transform = transforms.Compose(_image_steps)

# Mask pipeline: resize, force grayscale (in case the mask is not already),
# then convert to a tensor.
_mask_steps = [
    transforms.Resize(_target_size),
    transforms.Grayscale(),
    transforms.ToTensor(),
]
mask_transform = transforms.Compose(_mask_steps)

class CustomLoader(Dataset):
    """Dataset of image/mask pairs stored under ``<cwd>/dataset``.

    Sample ``idx`` is read from ``dataset/image/<idx>.jpg`` and
    ``dataset/mask/<idx>.png``. The dataset length is the number of entries
    in ``dataset/image``, so files are expected to be numbered contiguously
    from 0.

    Args:
        device: device the returned tensors are moved to.
    """

    def __init__(self, device='cpu'):
        self.device = device
        self.cwd = os.getcwd()
        self.len = len(os.listdir(f'{self.cwd}/dataset/image'))

    def __len__(self):
        """Return the number of samples."""
        return self.len

    def __getitem__(self, idx):
        """Load, transform and return sample ``idx``.

        Returns:
            dict with ``'images'`` (output of ``image_transform``) and
            ``'masks'`` (single-channel output of ``mask_transform``), both
            float tensors on ``self.device``.

        Raises:
            IndexError: if ``idx`` is outside ``[0, len(self))``. Raising
                ``IndexError`` (rather than failing on a missing file) keeps
                the class usable with the plain sequence/iteration protocol.
        """
        if not 0 <= idx < self.len:
            raise IndexError(f'index {idx} is out of range for dataset of size {self.len}')

        image_path = f'{self.cwd}/dataset/image/{idx}.jpg'
        mask_path = f'{self.cwd}/dataset/mask/{idx}.png'

        # convert() returns a fully loaded copy, so the files can be closed
        # right away instead of waiting for garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')
        with Image.open(mask_path) as raw_mask:
            mask = raw_mask.convert('L')  # Convert mask to grayscale

        # Apply the transformations
        image_tensor = image_transform(image).float()
        mask_tensor = mask_transform(mask).float()

        # If mask is not single-channel, take the first channel
        if mask_tensor.size(0) != 1:
            mask_tensor = mask_tensor[0].unsqueeze(0)

        # Move tensors to the specified device
        return {
            'images': image_tensor.to(self.device),
            'masks': mask_tensor.to(self.device),
        }

### U-Net 
After the disappointing performance of the previous model, we decided to implement a variant of the U-Net architecture.

![u-net](https://upload.wikimedia.org/wikipedia/commons/2/2b/Example_architecture_of_U-Net_for_producing_k_256-by-256_image_masks_for_a_256-by-256_RGB_image.png)

(https://en.wikipedia.org/wiki/U-Net)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleUnet(nn.Module):
    """Small U-Net variant producing a single-channel mask probability map.

    The encoder applies four 3x3 convolutions (256, 128, 64 and 32 output
    channels) with 2x2 max-pooling between them. The decoder upsamples by a
    factor of two three times; after each upsampling the result is
    concatenated with the encoder feature map of the same resolution (skip
    connection) and passed through a 3x3 convolution.

    Input:  ``(N, 3, H, W)`` with ``H`` and ``W`` divisible by 8, so that the
            upsampled maps line up with the skip connections.
    Output: ``(N, 1, H, W)`` with values in ``[0, 1]`` (sigmoid applied).
    """

    def __init__(self):
        super().__init__()

        # Down conv
        self.down_conv1 = nn.Conv2d(3, 256, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.down_conv2 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.down_conv3 = nn.Conv2d(128, 64, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.down_conv4 = nn.Conv2d(64, 32, kernel_size=3, padding=1)

        # Up conv. Each conv's in_channels is the sum of the upsampled map
        # and the skip connection it is concatenated with. Upsampling uses a
        # scale factor instead of a fixed size so any input resolution works.
        self.upsample1 = nn.Upsample(scale_factor=2)
        self.up_conv1 = nn.Conv2d(32 + 64, 64, kernel_size=3, padding=1)
        self.upsample2 = nn.Upsample(scale_factor=2)
        self.up_conv2 = nn.Conv2d(64 + 128, 128, kernel_size=3, padding=1)
        self.upsample3 = nn.Upsample(scale_factor=2)
        # One output channel: the target masks are single-channel.
        self.up_conv3 = nn.Conv2d(128 + 256, 1, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the mask probabilities for the batch ``x``."""
        _, x_out = self.forward_with_activations(x)
        return x_out

    def forward_with_activations(self, x):
        """Run the network and also return the intermediate feature maps.

        Returns:
            ``(activations, output)`` where ``activations`` maps a layer name
            to its 4-D output tensor and ``output`` equals ``forward(x)``.
        """
        # Encoder
        x1 = F.relu(self.down_conv1(x))
        x2 = F.relu(self.down_conv2(self.pool1(x1)))
        x3 = F.relu(self.down_conv3(self.pool2(x2)))
        x4 = F.relu(self.down_conv4(self.pool3(x3)))

        # Decoder with skip connections
        x5c = torch.cat([self.upsample1(x4), x3], dim=1)
        x6 = F.relu(self.up_conv1(x5c))
        x6c = torch.cat([self.upsample2(x6), x2], dim=1)
        x7 = F.relu(self.up_conv2(x6c))
        x7c = torch.cat([self.upsample3(x7), x1], dim=1)

        # No ReLU on the last conv: it would clamp the logits at zero and
        # the sigmoid could then never output a value below 0.5.
        x_out = torch.sigmoid(self.up_conv3(x7c))

        activations = {
            'down_conv1': x1,
            'down_conv2': x2,
            'down_conv3': x3,
            'down_conv4': x4,
            'up_conv1': x6,
            'up_conv2': x7,
            'output': x_out,
        }
        return activations, x_out



In [None]:
import torch.nn as nn
import torch.optim as optim

# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SimpleUnet().to(device)

# SimpleUnet.forward already ends with a sigmoid, so the loss must consume
# probabilities. BCEWithLogitsLoss would apply a second sigmoid on top.
criterion = nn.BCELoss()  # Binary Cross-Entropy Loss for binary classification
optimizer = optim.Adam(model.parameters(), lr=0.0004)


### Settings

- `batch_size`: sets the DataLoader batch size
- `debug`: enables debug mode (detailed view of what happens inside the layers during inference)

In [None]:
# Number of samples per DataLoader batch.
batch_size = 2

# When set, the training loop also calls visualize_activations.
debug_activations = False

# Debug plotting in the training loop. Only works when batch size is one.
debug = False

In [None]:
from torch.utils.data import DataLoader, random_split

# The full dataset; every sample is moved to `device` when it is loaded.
train_dataset = CustomLoader(device=device)

# Hold out 20% of the samples (rounded down) for validation and train on
# the remainder.
validation_size = int(0.2 * len(train_dataset))
train_size = len(train_dataset) - validation_size
train_subset, validation_subset = random_split(
    train_dataset,
    [train_size, validation_size],
)

# Both loaders share the same batch size and shuffle their subset.
train_loader, validation_loader = (
    DataLoader(subset, batch_size=batch_size, shuffle=True)
    for subset in (train_subset, validation_subset)
)

In [None]:
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np


def visualize_activations(activations):
    """Draw up to 15 feature maps of the first sample for every layer.

    Args:
        activations: mapping of layer name to an activation tensor that is
            indexed as ``[sample, feature_map]``.

    One matplotlib figure (a single row of grayscale images) is created per
    entry; the figures are not shown here.
    """
    max_maps = 15
    for layer_name, layer_output in activations.items():
        shown = min(layer_output.size(1), max_maps)

        # squeeze=False always yields a 2-D array of axes, so a layer with a
        # single feature map needs no special casing.
        _, axes_grid = plt.subplots(1, shown, figsize=(20, 2), squeeze=False)

        for channel, axis in enumerate(axes_grid[0]):
            # Detach from the graph and move to the CPU before converting to numpy
            feature_map = layer_output[0, channel].detach().cpu().numpy()
            axis.set_title(layer_name)
            axis.imshow(feature_map, cmap='gray')
            axis.axis('off')

def plot_loss(loss_values, val_loss_values):
    """Clear the cell output and draw the training and validation loss curves.

    Args:
        loss_values: training loss values to plot.
        val_loss_values: validation loss values to plot.

    Each curve gets its own 10x5 figure; the figures are not shown here.
    """
    clear_output(wait=True)

    curves = (
        (loss_values, 'Training Loss'),
        (val_loss_values, 'Validation Loss'),
    )
    for series, legend_label in curves:
        plt.figure(figsize=(10, 5))
        plt.plot(series, label=legend_label)
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend()

def plot_masks(image, true_mask, pred_mask, epoch, idx, loss_values, val_loss_values):
    """Plot the loss curves, then an image next to its true and predicted mask.

    When 4-D (batched) tensors are passed, only the first sample of the batch
    is drawn, so the function also works for batch sizes larger than one.

    Args:
        image: image tensor, ``(3, H, W)`` or a batch ``(N, 3, H, W)``.
        true_mask: ground-truth mask tensor, or ``None`` to leave the middle
            panel empty.
        pred_mask: predicted mask tensor.
        epoch: epoch number shown in the figure title.
        idx: batch index shown in the figure title.
        loss_values: training loss history forwarded to ``plot_loss``.
        val_loss_values: validation loss history forwarded to ``plot_loss``.
    """

    def first_sample(tensor):
        # A 4-D tensor is a batch; keep only its first element.
        return tensor[0] if tensor.dim() == 4 else tensor

    plot_loss(loss_values, val_loss_values)

    fig, axs = plt.subplots(1, 3, figsize=(10, 5))

    # Channels-first tensor -> channels-last array, as imshow expects.
    image_np = first_sample(image).detach().cpu().numpy().transpose(1, 2, 0)
    pred_mask_np = first_sample(pred_mask).squeeze().detach().cpu().numpy()

    if true_mask is not None:
        true_mask_np = first_sample(true_mask).squeeze().cpu().numpy()
        axs[1].imshow(true_mask_np, cmap='gray')
        axs[1].set_title('Ground Truth Mask')
        axs[1].axis('off')

    axs[0].imshow(image_np)
    axs[0].set_title('Original Image')
    axs[0].axis('off')

    axs[2].imshow(pred_mask_np, cmap='gray')
    axs[2].set_title('Predicted Mask')
    axs[2].axis('off')

    plt.suptitle(f'Epoch: {epoch}, Index: {idx}')

import os

NUM_EPOCHS = 100
TRAIN_DEBUG_INTERVAL = 500  # plot every N-th training step in debug mode
VAL_DEBUG_INTERVAL = 100    # plot every N-th validation step in debug mode
CHECKPOINT_DIR = 'backups'


def _forward(net, batch, want_activations):
    """Run ``net`` on ``batch`` and return ``(activations, outputs)``.

    ``activations`` is ``None`` unless they were requested and the model
    provides a ``forward_with_activations`` method.
    """
    if want_activations and hasattr(net, 'forward_with_activations'):
        return net.forward_with_activations(batch)
    return None, net(batch)


loss_values = []
val_loss_values = []

# torch.save does not create missing directories.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(NUM_EPOCHS):
    model.train()
    tra_loss = 0
    for i, sample in enumerate(train_loader):
        images, masks = sample['images'], sample['masks']
        is_debug_step = debug and (i + epoch) % TRAIN_DEBUG_INTERVAL == 0

        optimizer.zero_grad()
        # Activations are only captured on the steps that will display them.
        activations, outputs = _forward(model, images, is_debug_step and debug_activations)

        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        tra_loss += loss.item()

        if is_debug_step:
            pred_mask = outputs > 0.5  # Apply threshold to get binary mask
            plot_masks(images, masks, pred_mask, epoch+1, i+1, loss_values, val_loss_values)
            print(f'Epoch [{epoch+1}], Loss: {loss.item()}')

            if activations is not None:
                visualize_activations(activations)

            # Only render when something was actually drawn.
            plt.show()

    # Mean training loss of the epoch. It is recorded in debug mode as well,
    # so the curves drawn by plot_masks have data to show.
    tra_loss /= max(len(train_loader), 1)
    loss_values.append(tra_loss)

    if not debug:
        plot_loss(loss_values, val_loss_values)
        plt.show()
        print(f'Epoch [{epoch+1}], Loss: {tra_loss}')

    torch.save(model.state_dict(), f'{CHECKPOINT_DIR}/epoch{epoch}.pth')

    # Validation step
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # No need to track gradients for validation
        val_loss = 0
        for i, sample in enumerate(validation_loader):
            images, masks = sample['images'], sample['masks']
            activations, outputs = _forward(model, images, debug_activations)
            loss = criterion(outputs, masks)
            val_loss += loss.item()

            is_debug_step = debug and (i + epoch) % VAL_DEBUG_INTERVAL == 0
            if is_debug_step:
                pred_mask = outputs > 0.5  # Apply threshold to get binary mask
                plot_masks(images, masks, pred_mask, epoch+1, i+1, loss_values, val_loss_values)
                print('In Validation')

            if activations is not None:
                visualize_activations(activations)

            if is_debug_step or activations is not None:
                plt.show()

        # Mean validation loss of the epoch.
        val_loss /= max(len(validation_loader), 1)
        val_loss_values.append(val_loss)
        print(f'Epoch [{epoch+1}], Validation Loss: {val_loss}')

In [None]:
# Write the model's current weights (its state dict) to 'cnn.pth'.
final_weights = model.state_dict()
torch.save(final_weights, 'cnn.pth')
