# Masker Model

Léo Dupire

## Load Data

In [3]:
import os
import random
import tqdm.auto as tqdm
import torch
import numpy as np
import imageio.v3 as iio
import matplotlib.pyplot as plt

## PyTorch
from torch import nn, optim
from torch.utils.data import Dataset
from torchsummary import summary
import torch.optim.lr_scheduler as lr_scheduler

# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
# Every later cell moves the model and batches to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

In [4]:
#@title Import Tensor Data
# Training split
imgs = torch.load('data/imgs.pt')
masks = torch.load('data/masks.pt')

# Validation split
val_imgs = torch.load('data/val_imgs.pt')
val_masks = torch.load('data/val_masks.pt')

# Report the tensor shapes: images first, then masks, separated by a blank line
for _label, _tensor in (("Train imgs:", imgs), ("Val imgs:", val_imgs)):
    print(_label, _tensor.shape)
print()
for _label, _tensor in (("Train masks:", masks), ("Val masks:", val_masks)):
    print(_label, _tensor.shape)

In [7]:
# Side-by-side preview of val_imgs[0][0] (left) and its mask val_masks[0][0] (right)
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(6.4*2, 4.8))
img_ax, mask_ax = axes
img_ax.imshow(val_imgs[0][0])
mask_ax.imshow(val_masks[0][0])
plt.show()

## Dataset Augmentation and Loaders

### For Train & Val:

In [8]:
#@title Dataset Class & Loaders
class MaskDataset(Dataset):
  """Frame-level (image, mask) pairs for training/evaluating the masker.

  Both inputs are flattened to individual frames: `imgs` is reshaped to
  (-1, 160, 240, 3) and `masks` to (-1, 160, 240), so any leading
  dimensions are merged into one frame axis.

  Args:
    imgs: tensor of RGB frames, channel-last, reshapeable to (-1, 160, 240, 3).
    masks: tensor of per-pixel labels, reshapeable to (-1, 160, 240).
    transform: optional callable applied jointly as
      `img, msk = transform(img, msk)` after normalisation and flipping.
      `None` (default) applies nothing.
    augment: when True (default) each sample is horizontally flipped with
      probability 0.5. Pass False for evaluation data so results are
      deterministic and comparable between epochs.
  """

  def __init__(self, imgs, masks, transform=None, augment=True):
    self.imgs = imgs.reshape(-1, 160, 240, 3)
    self.masks = masks.reshape(-1, 160, 240)
    self.transform = transform
    self.augment = augment

  def __len__(self):
    return len(self.masks)

  def __getitem__(self, index):
    """Return `(img, msk)`: img is float (3, 160, 240), msk is (160, 240)."""
    img = self.imgs[index].to(torch.uint8)
    msk = self.masks[index]
    # HWC uint8 -> CHW float in [0, 1], then shifted/scaled to [-0.25, 0.25].
    # display_comp inverts exactly this mapping, so keep the two in sync.
    img = img.permute(2, 0, 1).to(torch.float) / 255
    img = (img - 0.5) / 2

    # Random horizontal flip, applied to image and mask together so they stay aligned
    if self.augment and random.random() > 0.5:
      img = torch.flip(img, dims=[2])
      msk = torch.flip(msk, dims=[1])

    if self.transform is not None:
      img, msk = self.transform(img, msk)

    return (img, msk)

# Datasets
train_dataset = MaskDataset(imgs, masks)
# No augmentation on validation: the loss drives checkpoint selection and the
# LR scheduler, so it must not depend on random flips.
val_dataset = MaskDataset(val_imgs, val_masks, augment=False)

# Data Loaders
batch_size = 64

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

## UNet - Masker

Code inspired by: https://pyimagesearch.com/2021/11/08/u-net-training-image-segmentation-models-in-pytorch/

In [13]:
#@title Conv Block
class conv_block(nn.Module):
    """Two consecutive 3x3 convolution -> batch-norm -> ReLU stages.

    The 3x3 kernels use padding=1, so height and width are preserved; only
    the channel count changes from `in_c` to `out_c`.
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        # Attribute names become state_dict keys of saved checkpoints - do not rename.
        self.conv1 = nn.Conv2d(in_channels=in_c, out_channels=out_c, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(num_features=out_c)

        self.conv2 = nn.Conv2d(in_channels=out_c, out_channels=out_c, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(num_features=out_c)

        # One stateless activation module shared by both stages
        self.relu = nn.ReLU()

    def forward(self, inputs):
        """Map (N, in_c, H, W) to (N, out_c, H, W)."""
        first_stage = self.relu(self.bn1(self.conv1(inputs)))
        second_stage = self.relu(self.bn2(self.conv2(first_stage)))
        return second_stage

In [14]:
#@title Encoder Block
class encoder_block(nn.Module):
    """Encoder stage: a conv_block followed by 2x2 max-pooling.

    `forward` returns both the un-pooled features (used as the skip
    connection) and the pooled features (input to the next stage).
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.conv = conv_block(in_c, out_c)
        self.pool = nn.MaxPool2d(kernel_size=(2, 2))

    def forward(self, inputs):
        """Return `(features, pooled)`, where pooled is features after the 2x2 max-pool."""
        features = self.conv(inputs)
        return features, self.pool(features)

In [15]:
#@title Decoder Block
class decoder_block(nn.Module):
    """Decoder stage: 2x transposed-conv up-sampling, skip concatenation, conv_block.

    The up-sampled map has `out_c` channels and is concatenated with a skip
    tensor on the channel axis, so the conv_block receives 2 * out_c channels
    (the skip is therefore expected to have `out_c` channels as well).
    """

    def __init__(self, in_c, out_c):
        super().__init__()

        self.up = nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2, padding=0)
        self.conv = conv_block(2 * out_c, out_c)

    def forward(self, inputs, skip):
        """Up-sample `inputs`, join with `skip` along channels, then refine."""
        upsampled = self.up(inputs)
        merged = torch.cat((upsampled, skip), dim=1)
        return self.conv(merged)

In [16]:
#@title UNet Block
class UNet(nn.Module):
    """U-Net with four encoder stages, a bottleneck and four decoder stages.

    Takes a 3-channel image batch and returns a 49-channel map of the same
    height and width (one channel of logits per class). Height and width
    must be divisible by 16: each of the four encoders halves them, and the
    decoders can only concatenate with the skips if the sizes line up again.
    """

    def __init__(self):
        super().__init__()

        # NOTE: attribute names are the state_dict keys of saved checkpoints - do not rename.

        # Contracting path: channels 3 -> 64 -> 128 -> 256 -> 512
        self.e1 = encoder_block(3, 64)
        self.e2 = encoder_block(64, 128)
        self.e3 = encoder_block(128, 256)
        self.e4 = encoder_block(256, 512)

        # Bottleneck at the lowest resolution
        self.b = conv_block(512, 1024)

        # Expanding path: channels 1024 -> 512 -> 256 -> 128 -> 64
        self.d1 = decoder_block(1024, 512)
        self.d2 = decoder_block(512, 256)
        self.d3 = decoder_block(256, 128)
        self.d4 = decoder_block(128, 64)

        # 1x1 convolution producing the 49 per-pixel class logits
        self.outputs = nn.Conv2d(64, 49, kernel_size=1, padding=0)

    def forward(self, inputs):
        """Return per-pixel class logits of shape (N, 49, H, W)."""
        # Encoder: keep every stage's un-pooled output for the skip connections
        skips = []
        x = inputs
        for encoder in (self.e1, self.e2, self.e3, self.e4):
            skip, x = encoder(x)
            skips.append(skip)

        x = self.b(x)

        # Decoder: consume the skips deepest-first (s4, s3, s2, s1)
        for decoder in (self.d1, self.d2, self.d3, self.d4):
            x = decoder(x, skips.pop())

        return self.outputs(x)

### (Option 1) Create a New Model

In [18]:
#@title New Model
model = UNet().to(device)

# Test forward pass (shape check only).
# Run it in eval mode and without autograd so the smoke test neither updates
# the BatchNorm running statistics nor builds a computation graph. The
# training loop calls model.train() itself at the start of every epoch.
model.eval()
with torch.no_grad():
    input_tensor = train_dataset[0][0].unsqueeze(0).to(device)
    output = model(input_tensor)
print(output.shape)

### (Option 2) Load a Trained Model

In [19]:
#@title Load a Model
model = UNet().to(device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine (and vice versa)
model.load_state_dict(torch.load("./masker_models/masker.pth", map_location=device))

# Test forward pass (shape check only).
# eval mode + no_grad: a train-mode forward pass would overwrite part of the
# BatchNorm running statistics that were just loaded from the checkpoint.
model.eval()
with torch.no_grad():
    input_tensor = train_dataset[0][0].unsqueeze(0).to(device)
    output = model(input_tensor)
print(output.shape)

### Model Specs

In [24]:
# Print a torchsummary report of the model for a 3x160x240 input,
# i.e. the (C, H, W) layout that MaskDataset returns for each image.
summary(model, input_size=(3, 160, 240))

## Train

In [None]:
# Adam over all model parameters (a starting lr of 1e-3 also works)
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
# Multiply the learning rate by 0.1 when the monitored loss stops decreasing (patience of 3 epochs)
scheduler = lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.1, patience=3, verbose=True)
# Per-pixel multi-class classification loss on the UNet logits
criterion = nn.CrossEntropyLoss()

# Per-epoch loss history, filled by the training loop
result = dict(train=[], val=[])
# Sentinel meaning "no validation loss recorded yet" (real losses are positive)
best_val = -1

In [None]:
#@title Display Mask Output Comparison
def display_comp(model, index, ds="val", show_img=False):
  """Plot the ground-truth mask next to the model's predicted mask.

  Args:
    model: segmentation network returning per-class logits of shape
      (1, C, H, W) for a single-image batch. Its train/eval mode is left
      untouched, so call `model.eval()` first when needed.
    index: index of the sample inside the chosen dataset.
    ds: "val" to read from `val_dataset`, "train" to read from `train_dataset`.
    show_img: when True, additionally show the (de-normalised) input image.

  Raises:
    ValueError: if `ds` is neither "val" nor "train".
  """
  if ds == "val":
    dataset = val_dataset
  elif ds == "train":
    dataset = train_dataset
  else:
    raise ValueError(f'ds must be "val" or "train", got {ds!r}')

  example_image, example_mask = dataset[index]

  # Inference only: no autograd graph is needed for plotting
  with torch.no_grad():
    logits = model(example_image.unsqueeze(0).to(device)).cpu().squeeze(0)
  pred_mask = logits.argmax(0)  # most likely class per pixel

  if show_img:
    # Invert the dataset normalisation ((x / 255) - 0.5) / 2 back to 0..255.
    # Round before the integer cast so float error cannot turn e.g. 254.999 into 254.
    rgb = ((example_image.permute(1, 2, 0) * 2 + 0.5) * 255).round().clamp(0, 255).to(torch.uint8)

    fig, axes = plt.subplots(1, 3, figsize=(12, 6))
    axes[0].imshow(rgb)  # RGB data: vmin/vmax would be ignored here
    axes[1].imshow(example_mask, vmin=0, vmax=48)
    axes[2].imshow(pred_mask, vmin=0, vmax=48)
  else:
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    # Fixed colour range 0..48 so class colours match between the two panels
    axes[0].imshow(example_mask, vmin=0, vmax=48)
    axes[1].imshow(pred_mask, vmin=0, vmax=48)
  plt.show()

In [None]:
#@title Training
# NOTE: progress bars use `tqdm.tqdm` from the `tqdm.auto` module imported at
# the top of the notebook. Re-importing `tqdm` as a function here would shadow
# that module and break the later `tqdm.tqdm(...)` call in the mask-generation cell.

num_epochs = 10

# Loss + Update Model function
def get_loss(image, mask, optimizer=None):
  """Compute the criterion for one batch and optionally take an optimizer step.

  Args:
    image: input batch, already on the model's device.
    mask: target batch; cast to long before being passed to the criterion.
    optimizer: when given, gradients are zeroed, back-propagated and a step is
      taken. When None the loss is only evaluated.

  Returns:
    The scalar loss tensor.
  """
  pred_mask = model(image)
  loss = criterion(pred_mask, mask.long())

  if optimizer is not None:
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  return loss

# Make sure the checkpoint directory exists before the first save
os.makedirs("./masker_models", exist_ok=True)

# Training loop
for epoch in tqdm.tqdm(range(1, num_epochs+1), leave=False): # Train on several epochs
  total_train_loss = 0.0
  model.train()
  for batch in tqdm.tqdm(train_loader, leave=False):
    image, mask = [x.to(device) for x in batch]
    # .item() detaches the value; summing the loss tensors themselves would
    # keep autograd history alive for the whole epoch.
    total_train_loss += get_loss(image, mask, optimizer=optimizer).item() # Get the loss and update

  train_loss = total_train_loss / len(train_loader)
  result["train"].append(train_loss) # Record loss for post-training visualization

  # Test on validation
  model.eval()
  total_val_loss = 0.0
  count = 0
  with torch.no_grad():
    for batch in val_loader:
      image, mask = [x.to(device) for x in batch]
      # Weight by batch size: the last batch may be smaller (drop_last=False)
      total_val_loss += get_loss(image, mask).item() * image.size(0) # Get the loss
      count += image.size(0)

    val_result = total_val_loss / count
    result["val"].append(val_result) # Record loss for post-training visualization
    print(f"Epoch {epoch} | Train: {train_loss:.4f} | Val: {val_result:.4f}") # Print Epoch losses

    if (best_val == -1) or (val_result < best_val):
        best_val = val_result
        torch.save(model.state_dict(), "./masker_models/best_masker.pth") # Save best model

    display_comp(model, 40) # Display prediction example on validation (function defined in cell above)
  scheduler.step(val_result) # Send the mean validation loss (a plain float) to lr_scheduler

## Results

In [None]:
# Plot the recorded per-epoch train and validation losses (epochs numbered from 1)
fig_prog = plt.figure(figsize=(6,4))

for _key, _legend in (("train", "Train"), ("val", "Val")):
    _losses = result[_key]
    plt.plot(range(1, len(_losses)+1), _losses, label=_legend)
plt.title("Reconstruction error over epoch", fontsize=14)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# View a specific example: input image, ground-truth mask and predicted mask
example_num = 4  # index of the sample inside the validation dataset
model.eval()  # display_comp does not change the model's mode itself
display_comp(model, example_num, ds="val", show_img=True)

## Generate Masks for Unlabeled Dataset

### For Unlabeled Data: _Lazy Loading_ (Optional)

__Only to be run once the Masker is fully trained!__ We will use the Masker to generate masks for the unlabeled data.

In [9]:
P = "./data/Dataset_Student" # Unlabeled data directory

# Collect the video folder names in the unlabeled folder.
# Names are bucketed by length (10 vs. 11 characters) and each bucket is sorted
# alphabetically, so every 10-character name comes before every 11-character one.
dir_list = os.listdir(f"{P}/unlabeled/")
lst1 = sorted(name for name in dir_list if len(name) == 10)
lst2 = sorted(name for name in dir_list if len(name) == 11)
dirs = lst1 + lst2
len(dirs)

13000

In [10]:
class LazyDataset(Dataset):
    """Lazily loads one unlabeled video (a folder of PNG frames) per item.

    Frames are read from `{P}/unlabeled/<folder>/image_<i>.png` only when an
    item is requested, so the whole unlabeled set never has to sit in memory.

    Args:
        dir_list: iterable of video folder names; `None` gives an empty dataset.
        num_frames: number of frames (`image_0.png` ... `image_{n-1}.png`)
            read from each folder. Defaults to 22.
    """

    def __init__(self, dir_list=None, num_frames=22):
        # Copy into a list so len() works and None does not crash it
        self.data_files = list(dir_list) if dir_list is not None else []
        self.num_frames = num_frames

    def __len__(self):
        return len(self.data_files)

    def __getitem__(self, index):
        """Return a float tensor of shape (num_frames, C, H, W) for one video."""
        folder = f"{P}/unlabeled/{self.data_files[index]}"
        video = []
        for i in range(self.num_frames): # Extract all images from corresponding video folder
            # np.array() copies the decoded image into a writable array, which
            # from_numpy can wrap directly (no float32 round trip needed).
            frame = np.array(iio.imread(f"{folder}/image_{i}.png"))
            img = torch.from_numpy(frame).to(torch.uint8)
            # Same normalisation as MaskDataset: HWC 0..255 -> CHW in [-0.25, 0.25]
            img = img.permute(2, 0, 1).to(torch.float) / 255
            img = (img - 0.5) / 2
            video.append(img)

        return torch.stack(video)

# One video per batch, in directory order, decoded by 8 background workers
unlabeled_dataset = LazyDataset(dir_list=dirs)
unlabeled_loader = torch.utils.data.DataLoader(
    unlabeled_dataset, batch_size=1, shuffle=False, num_workers=8) # Lazy Loader

### Generate & Save Masks

In [28]:
# Pre-allocate the output buffer: one (22, 160, 240) mask stack per unlabeled video.
# Sized from `dirs` so it always matches the number of videos the loader yields.
# Allocating it up front also serves as a rough check that enough memory is available.
# NOTE(review): with the default float dtype (4 bytes per element) 13000 videos need
# about 44 GB; masks only hold class ids 0..48, so dtype=torch.uint8 would cut this
# to about 11 GB - confirm that downstream consumers accept uint8 before switching.
unlabeled_masks = torch.zeros([len(dirs), 22, 160, 240])
unlabeled_masks.shape

In [None]:
#@title Unlabeled Images
# Import the progress bar under a private alias so this cell works no matter
# what the name `tqdm` is currently bound to in the notebook.
from tqdm.auto import tqdm as progress_bar

model.eval()
with torch.no_grad(): # Inference only: avoids storing activations for backprop
    for video_idx, batch in enumerate(progress_bar(unlabeled_loader)): # Go through all examples
        inp = batch.squeeze(0).to(device) # (1, 22, C, H, W) -> (22, C, H, W)
        logits = model(inp) # Named `logits` so the training tensor `masks` is not overwritten
        unlabeled_masks[video_idx] = logits.argmax(1).cpu() # Record in loader order

# Relative path, like every other data file in this notebook ('/data/...' would target the filesystem root)
torch.save(unlabeled_masks, 'data/unlabeled_masks.pt') # Save unlabeled masks as 'unlabeled_masks.pt' in ~/WNet/data