Imports

In [None]:
# Pytorch imports
import torch
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch import nn
import torchmetrics
import torchinfo

# Utils imports
import numpy as np
import os
import matplotlib.pyplot as plt

GNLDataLoader

In [2]:
import os
import dlib
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset
from torchnlp.encoders import LabelEncoder

import test

# Module-level debug switch. Nothing in the notebook cells visible here reads it:
# GNLDataLoader takes its own `debug` constructor argument instead.
debug_dl = True

class GNLDataLoader(Dataset):
    """
    Dataset of mouth-cropped video clips and character-level labels for the Lipsync Project.

    Each clip is decoded with OpenCV, the mouth region is located with dlib's
    68-point landmark predictor, cropped, resized and normalised per frame.
    The label of a clip is derived from the *name* of the matching entry in the
    labels folder (the content of that entry is never read).

    Data and label entries are paired by position in the two sorted directory
    listings, so both folders are expected to contain matching entries.
    """
    face_detector = dlib.get_frontal_face_detector()
    landmark = dlib.shape_predictor("shape_predictor_68_face_landmarks_GTX.dat")

    alphabet = list("abcdefghijklmnopqrstuvwxyz0123456789 ")
    encoder = LabelEncoder(alphabet, reserved_labels=['unknown'], unknown_index=0)
    CROPMARGIN = 20

    # Fixed geometry of every returned clip: (frames, height, width)
    MAX_FRAMES = 75
    MOUTH_HEIGHT = 100
    MOUTH_WIDTH = 150

    # Position-by-position expansion of the code found at the end of a file name.
    # The bare string "letter" marks the position whose character is kept as is.
    SENTENCE_ENCODING = (
        {"b": "bin", "l": "lay", "p": "place", "s": "set"},
        {"b": "blue", "g": "green", "r": "red", "w": "white"},
        {"a": "at", "b": "by", "i": "in", "w": "with"},
        "letter",
        {"z": "zero", "1": "one", "2": "two", "3": "three", "4": "four", "5": "five", "6": "six", "7": "seven", "8": "eight", "9": "nine"},
        {"a": "again", "n": "now", "p": "please", "s": "soon"},
    )

    def __init__(self, labels_path: str, data_path: str, transform = None, train_test_percent: int = 75, debug: bool = False) -> None:
        """
        Creates a dataset given the path to the labels and the video directory

        Parameters:
            - `labels_path`: the path to the directory whose entry names encode the labels;
            - `data_path`: the path to the directory with the videos;
            - `transform`: stored on the instance, but not applied anywhere in this class;
            - `train_test_percent`: accepted for compatibility, currently unused;
            - `debug`: when `True`, prints diagnostic information while loading.
        """
        super().__init__()
        self.debug: bool = debug

        if self.debug:
            print(f"[DEBUG] The data dir has{' ' if os.path.isdir(data_path) else ' not '}been recognized")
            print(f"[DEBUG] The label dir has{' ' if os.path.isdir(labels_path) else ' not '}been recognized")

        self.data_path, self.labels_path = data_path, labels_path
        # Sorted so that the i-th video and the i-th label line up
        self.data_dir, self.labels_dir = sorted(os.listdir(data_path)), sorted(os.listdir(labels_path))
        self.transform = transform


    def __len__(self) -> int:
        """
        Returns the length of the data/labels folder

        Returns:
            - `length` (`int`): the number of entries in the data folder
        """
        return len(self.data_dir)


    def __getitem__(self, index, straight: bool = False) -> tuple[list, list]:
        """
        Get the ith item(s) in the dataset

        Parameters:
            - `index`: an `int` or a `slice` selecting the entries to load;
            - `straight`: accepted for compatibility, currently unused.

        Returns:
            - (`items`, `labels`) (`tuple[list, list]`): the loaded videos and their
              labels. Both are lists, even when a single `int` index is given.
        """

        if self.debug:
            print(f"[DEBUG] Index of the dataloader: {index}")
            print(f"[DEBUG] Data folder: {self.data_dir[index]}")
            print(f"[DEBUG] Labels folder: {self.labels_dir[index]}")

        # Work on local copies: writing the wrapped lists back into the directory
        # listings would corrupt them for every later access.
        data_entries = self.data_dir[index]
        label_entries = self.labels_dir[index]
        if not isinstance(data_entries, list):
            data_entries = [data_entries]
        if not isinstance(label_entries, list):
            label_entries = [label_entries]

        return (
            [self.__load_video__(data_piece) for data_piece in data_entries],
            [self.__load_label__(label_piece) for label_piece in label_entries]
        )


    def __load_video__(self, video_path: str) -> np.ndarray:
        """
        Loads a video from the dataset given its path

        Parameters:
            - `video_path`: the name of the video inside the data directory

        Returns:
            - `video` (`np.ndarray`): array of shape `(MAX_FRAMES, MOUTH_HEIGHT, MOUTH_WIDTH)`
              holding the normalised mouth crop of every frame. Frames that could not be
              read, or in which no face was found, are left filled with zeros.

        Raises:
            - `OSError`: if the video cannot be opened.
        """
        video_path = os.path.join(self.data_path, video_path)
        cap = cv2.VideoCapture(video_path)
        if self.debug:
            print(f"[DEBUG] Trying to open the video at path {video_path}")

        # Zero-initialised so that frames which are never written hold a defined value
        to_return = np.zeros(shape=(self.MAX_FRAMES, self.MOUTH_HEIGHT, self.MOUTH_WIDTH))
        first_shape = None

        try:
            if not cap.isOpened():
                raise OSError(f"Could not open the video at path {video_path}")

            # Never write past the end of the fixed-size output buffer
            frame_count = min(int(cap.get(cv2.CAP_PROP_FRAME_COUNT)), self.MAX_FRAMES)

            for i in range(frame_count):
                ret, frame = cap.read()
                if not ret or frame is None:
                    # The reported frame count can exceed what can actually be decoded
                    break
                gframe = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

                if self.debug:
                    # Report whether every frame has the same shape as the first one
                    first_shape = gframe.shape if first_shape is None else first_shape
                    homog = first_shape == gframe.shape
                    print(gframe.shape, homog)

                facedetect = self.face_detector(gframe)
                if len(facedetect) == 0:
                    if self.debug:
                        print(f"[DEBUG] No face found in frame {i}, leaving it zero-filled")
                    continue

                # Mouth region bounded by landmarks 48, 54, 50 and 57 plus a margin;
                # clamped at 0 because a negative start would wrap the slice around.
                face_landmarks = self.landmark(gframe, facedetect[0])
                xleft = max(face_landmarks.part(48).x - self.CROPMARGIN, 0)
                xright = max(face_landmarks.part(54).x + self.CROPMARGIN, 0)
                ybottom = max(face_landmarks.part(57).y + self.CROPMARGIN, 0)
                ytop = max(face_landmarks.part(50).y - self.CROPMARGIN, 0)

                mouth = gframe[ytop:ybottom, xleft:xright]
                if mouth.size == 0:
                    # An empty crop cannot be resized
                    continue
                # cv2.resize takes the target size as (width, height)
                mouth = cv2.resize(mouth, (self.MOUTH_WIDTH, self.MOUTH_HEIGHT))

                # Per-frame standardisation; a constant crop has zero deviation,
                # in which case dividing by 1 keeps the result finite (all zeros).
                mean = np.mean(mouth)
                std_dev = np.std(mouth)
                to_return[i] = (mouth - mean) / (std_dev if std_dev > 0 else 1.0)
        finally:
            # Release the capture even when decoding or detection raises
            cap.release()

        return to_return


    def __load_label__(self, label_path: str) -> torch.Tensor:
        """
        Builds the label of a clip from the name of its label entry

        The last `_`-separated token of the name (before the first `.`) is read as a
        code with one character per word, expanded through `SENTENCE_ENCODING`, and the
        resulting characters are encoded with `encoder`.

        Parameters:
            - `label_path`: the name of the label entry;

        Returns:
            - `label`: the encoded character sequence, as returned by `encoder.batch_encode`
        """
        code = label_path.split(".")[0].split("_")[-1]
        if self.debug: print(code)

        sentence = []
        for position, letter in enumerate(code):
            lookup = self.SENTENCE_ENCODING[position]
            word = letter if lookup == "letter" else lookup[letter]
            # NOTE(review): every word, including the first, is preceded by a space,
            # so the sentence starts with " ". Kept as is; confirm this is intended.
            sentence.extend([" ", *word])

        enl = self.encoder.batch_encode(sentence)
        if self.debug: print(enl)
        return enl

CNN

In [3]:
import torch
from torch import nn
import torchinfo

class SelectItem(nn.Module):
    """Module that passes on a single element of an indexable input."""

    def __init__(self, item_index):
        """
        Parameters:
            - `item_index`: the index (or key) of the element that `forward` returns.
        """
        super().__init__()
        self.item_index = item_index
        self._name = 'selectitem'

    def forward(self, inputs):
        """Returns `inputs[item_index]`."""
        selected = inputs[self.item_index]
        return selected
class LabialCNN(nn.Module):
    """
    3D-convolutional front end followed by a bidirectional GRU and a per-frame classifier.

    `forward` accepts a clip shaped `(channels, frames, height, width)` or a batch
    shaped `(batch, channels, frames, height, width)` and returns one distribution
    over 37 classes per frame: `(frames, 37)` or `(frames, batch, 37)` respectively.

    The GRU input size of 1728 corresponds to 32 channels x 6 x 9, which is what the
    convolutional stack produces for 100 x 150 frames.
    """
    def __init__(self, debug: bool = False):
        """
        Parameters:
            - `debug`: when `True`, `forward` prints diagnostic information.
        """
        super().__init__()

        self.debug = debug
        # The temporal axis is never strided or pooled, so the frame count is preserved
        self.cnn = nn.Sequential(
            nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(3, 5, 5), padding=(1, 2, 2), stride=(1, 2, 2)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),

            nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(3, 5, 5), padding=(1, 2, 2), stride=(1, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),

            nn.Conv3d(in_channels=16, out_channels=32, kernel_size=(3, 5, 5), padding=(1, 2, 2), stride=(1, 1, 1)),
            nn.ReLU(),
            nn.MaxPool3d(kernel_size=(1, 2, 2), stride=(1, 2, 2)),
        )
        self.gru = nn.Sequential(
            nn.GRU(input_size=1728, hidden_size=256, num_layers=2, dropout=0.5, bidirectional=True),
            # nn.GRU returns (output, h_n); only the per-step output is classified
            SelectItem(0),

            # 512 = 2 directions x 256 hidden units
            nn.Linear(in_features=512, out_features=37),
            # Explicit class axis: the implicit choice is deprecated and picks a
            # different axis depending on the number of dimensions.
            # NOTE(review): nn.CTCLoss expects log-probabilities; confirm whether
            # this should be nn.LogSoftmax when training with it.
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        """
        Runs the clip(s) through the network

        Parameters:
            - `x`: tensor shaped `(C, T, H, W)` or `(N, C, T, H, W)`.

        Returns:
            - tensor shaped `(T, 37)` or `(T, N, 37)` with the per-frame class probabilities.
        """
        x = self.cnn(x)  # Run through the convolutional stack
        sh = x.shape

        # Bring the frame axis to the front and flatten the channel/spatial axes
        # of every frame. The axes must be swapped with `permute`: a plain reshape
        # only reinterprets memory and would mix values from different frames.
        if x.dim() == 5:
            x = x.permute(2, 0, 1, 3, 4).flatten(start_dim=2)  # (T, N, C*H*W)
        else:
            x = x.permute(1, 0, 2, 3).flatten(start_dim=1)     # (T, C*H*W)

        x = self.gru(x)

        if self.debug: print(f"Layer's shape: {sh}")
        if self.debug: print(f"Summary of the layer: a")

        return x

Loops

In [4]:
import torchmetrics
import torch

# Accuracy metric shared by train_loop and test_loop: both functions update, compute
# and reset this same module-level object.
# NOTE(review): 37 matches the number of symbols in GNLDataLoader.alphabet, but the
# label encoder also reserves an 'unknown' index — confirm the class count.
metric = torchmetrics.Accuracy(task="multiclass", num_classes=37)

def train_loop(device, dataloader, model, loss_fn, optimizer, epochs, epoch=None, debug=True):
    """Trains an epoch of the model
    
    Parameters:
        - `device`: destination device
        - `dataloader`: the dataloader of the dataset
        - `model`: the model used
        - `loss_fn`: the loss function of the model, called as `loss_fn(pred, y)`
        - `optimizer`: the optimizer
        - `epochs`: the total number of epochs (only used in the progress messages)
        - `epoch`: the index of the epoch (only used in the progress messages)
        - `debug`: when `True`, prints the loss and accuracy every 32 batches
          and the accuracy of the whole epoch at the end
    """
    size = len(dataloader)  # number of batches

    # Make sure dropout behaves as in training, whatever ran before
    model.train()
    # The metric keeps its state on its own device; it must match the predictions
    metric.to(device)

    # Get the batch from the dataset
    for batch, (x, y) in enumerate(dataloader):
        # Move data to the device used
        x = x.to(device)
        y = y.to(device)

        # Compute the prediction and the loss
        pred = model(x)
        loss = loss_fn(pred, y)

        # Adjust the weights, starting from clean gradients
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate on every batch so that the epoch accuracy covers the whole epoch
        accuracy = metric(pred, y)

        # Print some information
        if batch % 32 == 0:
            # Both numbers are batch counts, so that "x/size" is consistent
            loss_value, current_batch = loss.item(), batch + 1
            if debug: print(f"→ Loss: {loss_value} [Batch {current_batch}/{size}, Epoch {epoch}/{epochs}]")
            if debug: print(f"Accuracy of batch {current_batch}/{size}: {accuracy}")
        
    accuracy = metric.compute()
    print(f"=== The epoch {epoch}/{epochs} has finished training ===")
    if debug: print(f"→ Final accuracy of the epoch: {accuracy}")
    metric.reset()

def test_loop(device, dataloader, model, loss_fn, debug=True):
    """Evaluates the model on a dataloader and reports its accuracy

    Parameters:
        - `device`: destination device
        - `dataloader`: the dataloader of the test dataset
        - `model`: the model used
        - `loss_fn`: accepted for symmetry with `train_loop`, currently unused
        - `debug`: when `True`, prints the accuracy of every batch and the final accuracy
    """
    # Evaluate with dropout disabled, then put the model back in whatever mode it was in
    was_training = model.training
    model.eval()
    # The metric keeps its state on its own device; it must match the predictions
    metric.to(device)

    try:
        # Disable the gradient tracking: no weights are updated here
        with torch.no_grad():
            for index, (x, y) in enumerate(dataloader):
                # Move the data to the device used for testing
                x = x.to(device)
                y = y.to(device)

                # Get the model prediction
                pred = model(x)

                # Get the accuracy score
                acc = metric(pred, y)
                if debug: print(f"→ Accuracy for image {index}: {acc}")
    finally:
        model.train(was_training)

    acc = metric.compute()
    print(f"===    The testing loop has finished    ===")
    if debug: print(f"→ Final testing accuracy of the model: {acc}")
    metric.reset()

Data Loading

In [None]:
# Create the dataset and the dataloaders of our project
path_data = "data/matching/fronts" # "data/lombardgrid_front/lombardgrid/front"
path_labels = "data/matching/labels" # "data/lombardgrid_alignment/lombardgrid/alignment"

dataset = GNLDataLoader(path_labels, path_data, transform=None, debug=False)

# Test
print(
    f"[DEBUG] Items in the data folder: {len(os.listdir(path_data))}",
    f"[DEBUG] Items in the labels folder: {len(os.listdir(path_labels))}",
    sep="\n"
)

# Split by index instead of slicing the dataset: `dataset[0:128]` would decode every
# clip right here and hand the DataLoader a 2-tuple of lists rather than a dataset.
# The bounds are clamped so that a smaller folder shortens the splits, as slicing did.
n_items = len(dataset)
train_indices = range(min(128, n_items))
test_indices = range(min(128, n_items), min(192, n_items))

# NOTE(review): the default collation stacks the samples of a batch, which requires
# them to have the same size; the labels are character sequences whose length
# depends on the words — a custom `collate_fn` is presumably needed. TODO confirm.
dataloader_train = DataLoader(torch.utils.data.Subset(dataset, train_indices), batch_size=32, shuffle=True)
# Evaluation does not depend on the order, so the test split is not shuffled
dataloader_test = DataLoader(torch.utils.data.Subset(dataset, test_indices), batch_size=32, shuffle=False)

Model + Hyperparameters

In [None]:
# Use the GPU whenever PyTorch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

model = LabialCNN(debug=False)
model = model.to(device)

# Print the summary of the model for one clip shaped (channels, frames, height, width)
torchinfo.summary(
    model,
    (1, 75, 100, 150),
    col_names=("input_size", "output_size", "num_params", "kernel_size", "mult_adds"),
    verbose=1,
)

# Hyperparameters.
# `batch_size` and `dropout` are not read by the cells visible here: the dataloaders
# and the model set their own values.
epochs = 2
batch_size = 50
learning_rate = 1e-4
dropout = 0.5

loss_fn = nn.CTCLoss(reduction="mean")
optimizer = torch.optim.AdamW(params=model.parameters(), lr=learning_rate)

Training + Testing

In [None]:
# One training pass followed by one evaluation pass per epoch.
# Both loops take the device as their first argument, and train_loop also needs the
# total number of epochs; the epoch index is passed 1-based for the progress messages.
for epoch_ind in range(epochs):
    train_loop(device, dataloader_train, model, loss_fn, optimizer, epochs, epoch=epoch_ind + 1, debug=False)
    test_loop(device, dataloader_test, model, loss_fn, debug=False)

print("=== The training has finished ===")