In [None]:
!pip install scipy scikit-image torch tqdm transformers mediapipe opencv-python torchvision numpy pandas timm evaluate facenet-pytorch

In [1]:
from google.colab import drive
from pathlib import Path
drive.mount("/content/drive")
project_path = Path("/content/drive/MyDrive/NLP/MultiModalEmotionRecognition/")

Mounted at /content/drive


In [3]:
%cd /content/drive/MyDrive/NLP/MultiModalEmotionRecognition/data

/content/drive/MyDrive/NLP/MultiModalEmotionRecognition/data


In [4]:
!ls

correct_indexes    image_index_test.txt   sentiment_train.txt
dev.zip		   image_index_train.txt  sentiment_val.txt
english_test.txt   image_index_val.txt	  test.zip
english_train.txt  images		  train_ende.zip
english_val.txt    saved_features
error_indexes	   sentiment_test.txt


In [None]:
!unzip dev.zip
!unzip test.zip
!unzip train_ende.zip

In [None]:
%mv dev/ images/val
%mv test/ images/test
%mv train_ende/ images/train

In [57]:
%ls images/test -1 | wc -l

5068


In [58]:
%cd ..

/content/drive/MyDrive/NLP/MultiModalEmotionRecognition


In [6]:
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [210]:
import torch
import numpy as np
import torchvision
import matplotlib.pyplot as plt
from torchvision.transforms import transforms as transforms
from torchvision.models.detection import KeypointRCNN_ResNet50_FPN_Weights


class PoseEmbeddingExtractor:
    def __init__(
        self,
        device='cpu'
    ):
        self.model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.DEFAULT,num_keypoints=17).to(device)
        self.model.eval()
        self.device = device
        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def extract_embedding(self, image):
        image = self.transform(image)
        image = image.unsqueeze(0).to(self.device)
        with torch.no_grad():
            outputs = self.model(image)
        
        keypoints_scores = outputs[0]['keypoints_scores']
        best_score = torch.mean(keypoints_scores, axis=1).argmax().item()
        keypoints = outputs[0]['keypoints'][best_score,:,:2]
        return keypoints.ravel()

# p = PoseEmbeddingExtractor(device=device)
# path = 'data/images/val/4965.jpg'
# img = cv2.cvtColor(cv2.imread(path), cv2.COLOR_BGR2RGB)
# p.extract_embedding(img).shape

In [214]:
import sys, os, torch, cv2
from pathlib import Path

from glob import glob
from tqdm import tqdm


def remove_non_poses(input_dir, split):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    pee = PoseEmbeddingExtractor(device=device)
    file_name = f"pose_error_{split}.txt"
    os.makedirs("./data/error_indexes", exist_ok=True)
    non_pose_files = open(f"./data/error_indexes/{file_name}", "w")
    img_pattern = os.path.join(input_dir, "*.jpg")
    images = glob(img_pattern)
    for image_path in tqdm(images):
        try:
            img = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
            pee.extract_embedding(img)

        except Exception as e:
            print(e)
            img_id = image_path.split("/")[-1].split(".")[0]
            non_pose_files.write(f"{img_id}")
            non_pose_files.write(os.linesep)


split = "test"
input_dir = project_path / "data" / "images" / split
remove_non_poses(input_dir, split)


 11%|█         | 533/5067 [02:15<11:11,  6.76it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 17%|█▋        | 857/5067 [03:05<13:11,  5.32it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 22%|██▏       | 1121/5067 [03:48<11:01,  5.97it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 24%|██▍       | 1208/5067 [04:02<09:44,  6.60it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 37%|███▋      | 1867/5067 [05:51<08:41,  6.14it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 38%|███▊      | 1948/5067 [06:06<08:57,  5.80it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 40%|███▉      | 2013/5067 [06:16<08:04,  6.31it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 40%|███▉      | 2020/5067 [06:18<08:25,  6.02it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 40%|███▉      | 2025/5067 [06:18<08:04,  6.28it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 43%|████▎     | 2184/5067 [06:45<07:39,  6.27it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 46%|████▌     | 2327/5067 [07:08<06:45,  6.76it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 46%|████▌     | 2341/5067 [07:10<06:47,  6.69it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 47%|████▋     | 2357/5067 [07:12<06:34,  6.87it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 47%|████▋     | 2359/5067 [07:13<06:26,  7.00it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 47%|████▋     | 2366/5067 [07:14<06:50,  6.59it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 52%|█████▏    | 2653/5067 [07:57<05:46,  6.96it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 58%|█████▊    | 2922/5067 [08:39<05:07,  6.99it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 60%|██████    | 3065/5067 [09:01<05:00,  6.65it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 70%|██████▉   | 3537/5067 [10:13<03:50,  6.63it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 71%|███████   | 3584/5067 [10:20<03:50,  6.42it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 84%|████████▍ | 4275/5067 [12:07<02:05,  6.32it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 90%|████████▉ | 4536/5067 [12:47<01:22,  6.44it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


 92%|█████████▏| 4663/5067 [13:06<00:59,  6.76it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


100%|█████████▉| 5066/5067 [14:09<00:00,  6.79it/s]

argmax(): Expected reduction dim to be specified for input.numel() == 0.


100%|██████████| 5067/5067 [14:09<00:00,  5.97it/s]


#Face Embedding

In [225]:
from scipy.spatial.distance import euclidean
import math
from skimage.transform import rotate
from mtcnn import MTCNN
from facenet_pytorch import MTCNN as MTCNN2
import mediapipe
import numpy as np
import pandas as pd
import cv2
import os
from PIL import Image
import torch
from torchvision import transforms
import urllib


def get_model_path(model_name):
    model_file = model_name + ".pt"
    cache_dir = os.path.join(os.path.expanduser("~"), ".hsemotions")
    # cache_dir = "emotion_models"
    os.makedirs(cache_dir, exist_ok=True)
    fpath = os.path.join(cache_dir, model_file)
    if not os.path.isfile(fpath):
        print(f"{model_file} not exists")
        url = (
            "https://github.com/HSE-asavchenko/face-emotion-recognition/blob/main/models/affectnet_emotions/"
            + model_file
            + "?raw=true"
        )
        print("Downloading", model_name, "from", url)
        urllib.request.urlretrieve(url, fpath)

    return fpath


class FaceAlignment:
    def __init__(
        self,
    ):
        pass

    @staticmethod
    def apply_rotation_on_images(input_images, angles):
        rotated_images = [
            rotate(image, angle) for image, angle in zip(input_images, angles)
        ]
        return rotated_images

    @staticmethod
    def compute_alignment_rotation_(eyes_coordinates):
        angles = []
        directions = []
        for left_eye_coordinate, right_eye_coordinate in eyes_coordinates:

            left_eye_x, left_eye_y = left_eye_coordinate
            right_eye_x, right_eye_y = right_eye_coordinate

            triangle_vertex = (
                (right_eye_x, left_eye_y)
                if left_eye_y > right_eye_y
                else (left_eye_x, right_eye_y)
            )
            direction = (
                -1 if left_eye_y > right_eye_y else 1
            )  # rotate clockwise else counter-clockwise

            # compute length of triangle edges
            a = euclidean(left_eye_coordinate, triangle_vertex)
            b = euclidean(right_eye_coordinate, triangle_vertex)
            c = euclidean(right_eye_coordinate, left_eye_coordinate)

            # cosine rule
            if (
                b != 0 and c != 0
            ):  # this multiplication causes division by zero in cos_a calculation
                cos_a = (b**2 + c**2 - a**2) / (2 * b * c)
                angle = np.arccos(cos_a)  # angle in radian
                angle = (angle * 180) / math.pi  # radian to degree
            else:
                angle = 0

            angle = angle - 90 if direction == -1 else angle

            angles.append(angle)
            directions.append(direction)

        return angles, directions


class FaceDetection:

    # first call extract_face
    def __init__(self, model_name, minimum_confidence):

        self.detected_faces_information = None
        self.model_name = model_name
        self.minimum_confidence = minimum_confidence
        if model_name == "MTCNN":
            detector_model = MTCNN2(device=device)
            self.detect_faces_function = (
                lambda input_image: detector_model.detect(input_image, landmarks=True)
            )

    def extract_faces(self, input_image, return_detections_information=True):
        self.detect_faces__(input_image)
        faces = self.get_faces__(
            input_image,
        )
        if return_detections_information:
            return faces, self.detected_faces_information

        else:
            return faces

    def detect_faces__(self, input_image):
        detections = self.detect_faces_function(input_image)
        detections = [
            {
                'box': detections[0][i],
                'confidence': detections[1][i],
                'keypoints': {
                    'left_eye': detections[2][i][0],
                    'right_eye': detections[2][i][1],
                    'nose': detections[2][i][2],
                    'mouth_left': detections[2][i][3], 
                    'mouth_right': detections[2][i][4]
                }
              
            }
            for i in range(detections[0].shape[0])]
        self.detected_faces_information = list(
            filter(
                lambda element: element["confidence"] > self.minimum_confidence,
                detections,
            )
        )


    def get_detected_faces_information(self):
        return self.detected_faces_information

    def get_keypoints(
        self,
    ):
        return list(
            map(lambda element: element["keypoints"], self.detected_faces_information)
        )

    def get_faces__(
        self,
        input_image,
    ):
        boxes = [
            detection_information["box"]
            for detection_information in self.detected_faces_information
        ]
        y1y2x1x2 = [(int(y), int(y2), int(x), int(x2)) for x, y, x2, y2 in boxes]
        faces = [input_image[y1:y2, x1:x2] for y1, y2, x1, x2 in y1y2x1x2]
        return faces

    def get_eyes_coordinates(
        self,
    ):
        eyes_coordinates = [
            (info["keypoints"]["left_eye"], info["keypoints"]["right_eye"])
            for info in self.detected_faces_information
        ]
        return eyes_coordinates


class FaceEmotionRecognizer:
    # supported values of model_name: enet_b0_8_best_vgaf, enet_b0_8_best_afew, enet_b2_8, enet_b0_8_va_mtl, enet_b2_7
    def __init__(self, device, model_name="enet_b0_8_best_vgaf"):
        self.device = device
        self.is_mtl = "_mtl" in model_name
        if "_7" in model_name:
            self.idx_to_class = {
                0: "Anger",
                1: "Disgust",
                2: "Fear",
                3: "Happiness",
                4: "Neutral",
                5: "Sadness",
                6: "Surprise",
            }
        else:
            self.idx_to_class = {
                0: "Anger",
                1: "Contempt",
                2: "Disgust",
                3: "Fear",
                4: "Happiness",
                5: "Neutral",
                6: "Sadness",
                7: "Surprise",
            }

        self.img_size = 224 if "_b0_" in model_name else 260
        self.test_transforms = transforms.Compose(
            [
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        path = get_model_path(model_name)

        model = torch.load(path)
        model = model.to(device)

        if isinstance(model.classifier, torch.nn.Sequential):
            self.classifier_weights = model.classifier[0].weight.data
            self.classifier_bias = model.classifier[0].bias.data
        else:
            self.classifier_weights = model.classifier.weight.data
            self.classifier_bias = model.classifier.bias.data

        model.classifier = torch.nn.Identity()
        self.model = model.eval()
        # print(path, self.test_transforms)

    def compute_probability(self, features):
        return torch.matmul(features, self.classifier_weights.T) + self.classifier_bias

    def extract_representations_from_faces(self, input_faces):
        faces = [self.test_transforms(Image.fromarray(face)) for face in input_faces]
        features = self.model(torch.stack(faces, dim=0).to(self.device))
        return features

    def predict_emotions_from_representations(
        self, representations, logits=True, return_features=True
    ):
        scores = self.compute_probability(representations)
        if self.is_mtl:
            predictions_indices = torch.argmax(scores[:, :-2], dim=1)

        else:
            predictions_indices = torch.argmax(scores, dim=1)

        if self.is_mtl:
            x = scores[:, :-2]

        else:
            x = scores
        pred = torch.argmax(x[0])

        if not logits:
            e_x = torch.exp(x - torch.max(x, dim=1)[:, None])
            e_x = e_x / e_x.sum(dim=1)[:, None]
            if self.is_mtl:
                scores[:, :-2] = e_x
            else:
                scores = e_x

        return [
            self.idx_to_class[pred.item()] for pred in (predictions_indices)
        ], scores


class FaceNormalizer:
    def __init__(self):
        self.mp_face_mesh = mediapipe.solutions.face_mesh
        face_mesh = self.mp_face_mesh.FaceMesh(static_image_mode=True)

        mp_face_mesh = mediapipe.solutions.face_mesh
        self.face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True)
        self.routes_idx = self.initialize__()

    def initialize__(self):
        df = pd.DataFrame(
            list(self.mp_face_mesh.FACEMESH_FACE_OVAL), columns=["p1", "p2"]
        )
        routes_idx = []

        p1 = df.iloc[0]["p1"]
        p2 = df.iloc[0]["p2"]

        for i in range(0, df.shape[0]):
            obj = df[df["p1"] == p2]
            p1 = obj["p1"].values[0]
            p2 = obj["p2"].values[0]

            route_idx = []
            route_idx.append(p1)
            route_idx.append(p2)
            routes_idx.append(route_idx)

        return routes_idx

    def get_landmarks__(self, input_image: np.ndarray):
        if input_image.dtype == np.float:
            input_image = (input_image * 255).astype(np.uint8)

        results = self.face_mesh.process(input_image)
        landmarks = results.multi_face_landmarks[0]

        routes = []
        # for source_idx, target_idx in mp_face_mesh.FACEMESH_FACE_OVAL:
        for source_idx, target_idx in self.routes_idx:
            source = landmarks.landmark[source_idx]
            target = landmarks.landmark[target_idx]

            relative_source = (
                int(input_image.shape[1] * source.x),
                int(input_image.shape[0] * source.y),
            )
            relative_target = (
                int(input_image.shape[1] * target.x),
                int(input_image.shape[0] * target.y),
            )

            # cv2.line(img, relative_source, relative_target, (255, 255, 255), thickness = 2)

            routes.append(relative_source)
            routes.append(relative_target)

        return routes

    @staticmethod
    def normalize_with_landmark_points__(input_image, landmarks):
        mask = np.zeros((input_image.shape[0], input_image.shape[1]))
        mask = cv2.fillConvexPoly(mask, np.array(landmarks), 1)
        mask = mask.astype(bool)

        out = np.zeros_like(input_image)
        out[mask] = input_image[mask]
        return out

    def normalize_faces_image(self, input_images):
        normalized_faces_images = [
            self.normalize_with_landmark_points__(
                input_image, self.get_landmarks__(input_image)
            )
            for input_image in input_images
        ]
        return normalized_faces_images


class FaceEmbeddingExtractor:
    def __init__(
        self,
        device='cuda'
    ):
        self.faces = None
        self.normalized_rotated_faces = None
        self.rotated_faces = None
        self.rotation_angles = None
        self.rotation_directions = None

        fd = FaceDetection("MTCNN", minimum_confidence=0.95)
        self.face_detection_model: FaceDetection = fd
        fa = FaceAlignment()
        self.face_alignment_model: FaceAlignment = fa
        fn = FaceNormalizer()
        self.face_normalizer_model: FaceNormalizer = fn
        model_name = "enet_b0_8_best_afew"
        fer = FaceEmotionRecognizer(device, model_name)
        self.face_emotion_recognition_model: FaceEmotionRecognizer = fer


    def extract_embedding(self, input_image):
        faces, detected_faces_information = self.face_detection_model.extract_faces(
            input_image, return_detections_information=True
        )

        (
            rotation_angles,
            rotation_directions,
        ) = self.face_alignment_model.compute_alignment_rotation_(
            self.face_detection_model.get_eyes_coordinates()
        )
        rotated_faces = self.face_alignment_model.apply_rotation_on_images(
            faces, rotation_angles
        )
        normalized_rotated_faces = self.face_normalizer_model.normalize_faces_image(
            rotated_faces
        )

        normalized_rotated_faces_255 = [
            (image * 255).astype(np.uint8) for image in normalized_rotated_faces
        ]

        representations = (
            self.face_emotion_recognition_model.extract_representations_from_faces(
                normalized_rotated_faces_255
            )
        )[0] #WARNING: 0 was not here
        del normalized_rotated_faces_255
        del normalized_rotated_faces
        del rotated_faces
        del rotation_angles
        del rotation_directions
        del faces
        del detected_faces_information
        # (
        #     predictions,
        #     scores,
        # ) = self.face_emotion_recognition_model.predict_emotions_from_representations(
        #     representations
        # )

        # self.faces = faces
        # self.rotation_angles, self.rotation_directions = (
        #     rotation_angles,
        #     rotation_directions,
        # )
        # self.rotated_faces = rotated_faces
        # self.normalized_rotated_faces = normalized_rotated_faces_255

        return None, None, representations
        # return preictions, scores, representations

    def get_rotations_information(self):
        return self.rotation_angles, self.rotation_directions

    def get_faces(self):
        return self.faces

    def get_rotated_faces(self):
        return self.rotated_faces

    def get_normalized_rotated_faces(self):
        return self.normalized_rotated_faces

    def clear(self):
        self.faces = None
        self.normalized_rotated_faces = None
        self.rotated_faces = None
        self.rotation_angles = None
        self.rotation_directions = None

    def store_embeddings(self, file, embeddings):
        with open(file, "wb") as file_out:
            pickle.dump(
                {"embeddings": embeddings}, file_out, protocol=pickle.HIGHEST_PROTOCOL
            )

    def load_embeddings(self, file):
        with open(file, "rb") as file_in:
            stored_data = pickle.load(file_in)
            stored_embeddings = stored_data["embeddings"]

        return stored_embeddings

#Text Embedding


In [226]:
from transformers import AutoTokenizer, AutoModel, pipeline
from transformers import RobertaForSequenceClassification
import torch
import pickle


class TextEmbeddingExtractor:
    def __init__(
        self,
        model_name="pysentimiento/robertuito-sentiment-analysis",
        show_progress_bar=True,
        to_tensor=True,
        max_length=128,
        device='cuda'
    ):
        self.model_name = model_name
        self.device = device
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        self.model = RobertaForSequenceClassification.from_pretrained(
            self.model_name, num_labels=3, output_hidden_states=True
        ).to(self.device)

        self.generator = pipeline(
            task="sentiment-analysis",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def extract_embedding(
        self,
        input_batch_sentences,
    ):
        encoded_input = self.tokenizer(
            input_batch_sentences,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            model_output = self.model(**encoded_input)
            hidden_states = model_output["hidden_states"]
            last_layer_hidden_states = hidden_states[
                12
            ]  # 12 = len(hidden_states) , dim = (batch_size, seq_len, 768)
            cls_hidden_state = last_layer_hidden_states[:, 0, :]

        return cls_hidden_state

    def get_labels(self, input_batch_sentences):
        return self.generator(input_batch_sentences)


#Dataset

In [227]:
FACE_EMBEDDING_SIZE = 1280
TEXT_EMBEDDING_SIZE = 768
POSE_EMBEDDING_SIZE = 34
SCENE_EMBEDDING_SIZE = None


In [228]:
!free

              total        used        free      shared  buff/cache   available
Mem:       13298572     9226444      182716       25804     3889412     6754464
Swap:             0           0           0


In [229]:
import os, cv2, torch, ast
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import trange
from tqdm import tqdm


class MSCTDDataSet(Dataset):
    """MSCTD dataset."""

    def __init__(self, base_path="data/", split="train", data_size=None, load=False, raw=False):
        """
        Args:
            base_path (str or path): path to data folder
            split (str): dev, train, test
        """
        if isinstance(base_path, str):
            base_path = Path(base_path)
        self.base_path = base_path
        self.load_path = base_path / 'saved_features'
        self.split = split
        self.text_file_path = base_path / f"english_{split}.txt"
        self.seq_file_path = base_path / f"image_index_{split}.txt"
        self.sentiment_file_path = base_path / f"sentiment_{split}.txt"
        self.image_dir = base_path / "images" / split
        self.correct_indexes_file_path = base_path / "correct_indexes" / f"correct_indexes_{split}.txt"

        self.data_size = data_size
        self.load = load
        self.raw = raw

        self.texts = None
        self.sentiments = None
        self.indexes = None
        self.face_embeddings = None
        self.pose_embeddings = None
        self.text_embeddings = None
        self.load_data()
        self.face_embedding_extractor = FaceEmbeddingExtractor(device=device)
        self.text_embedding_extractor = TextEmbeddingExtractor(device=device)
        self.pose_embedding_extractor = PoseEmbeddingExtractor(device=device)



    def load_data(self):
        with open(self.text_file_path) as text_file, open(self.sentiment_file_path) as sentiment_file, open(self.correct_indexes_file_path) as correct_file:
            texts = [t.strip() for t in text_file.readlines()]
            sentiments = [int(t.strip()) for t in sentiment_file.readlines()]
            face_embeddings = None
            pose_embeddings = None
            text_embeddings = None
            corrects = [int(c.strip()) for c in correct_file.readlines()]
            if self.load:
                try:
                    face_embeddings = torch.load(self.save_path / f'face_embeddings_{self.split}.pt')
                    pose_embeddings = torch.load(self.save_path / f'pose_embeddings_{self.split}.pt')
                    text_embeddings = torch.load(self.save_path / f'text_embeddings_{self.split}.pt')
                    corrects = torch.load(self.save_path / f'real_indexes_{self.split}.pt')
                except Exception as e:
                    print(e)
                    print('Warning: passed load=True but not embedding file was located. Not loading')

            correct_texts = [texts[i] for i in corrects]
            correct_sentiments = [sentiments[i] for i in corrects]
        # with open(self.image_index_path) as f:
        #     images = [ast.literal_eval(t.strip()) for t in f.readlines()]

        if self.data_size:
            correct_texts = correct_texts[: self.data_size]
            correct_sentiments =correct_sentiments[: self.data_size]
            if face_embeddings:
                face_embeddings = face_embeddings[:self.data_size,:]
            if pose_embeddings:
                pose_embeddings = pose_embeddings[:self.data_size,:]
            if text_embeddings:
                face_embeddings = text_embeddings[:self.data_size,:]
            # images = images[: self.data_size]


        self.texts = correct_texts
        self.text_embeddings = text_embeddings
        self.sentiments = correct_sentiments
        self.indexes = corrects
        self.face_embeddings = face_embeddings
        self.pose_embeddings = pose_embeddings



    def __len__(self):
        return len(self.texts)

    def get_face_embedding(self, index, image):
        if self.load:
            return self.face_embeddings[index]
        (
            predictions,
            scores,
            representations,
        ) = self.face_embedding_extractor.extract_embedding(image)
        return representations

    def get_pose_embedding(self, index, image):
        if self.load:
            return self.pose_embeddings[index]
        return self.pose_embedding_extractor.extract_embedding(image)

    def get_image_embeddings(self, index):
        image = None
        real_index = self.indexes[index]
        image_name = self.image_dir / f"{real_index}.jpg"
        if not self.load:
            image = cv2.cvtColor(cv2.imread(str(image_name)), cv2.COLOR_BGR2RGB)

        face_embedding = self.get_face_embedding(index, image)
        pose_embedding = self.get_pose_embedding(index, image)
        return face_embedding, pose_embedding

    def get_sentiment(self, index):
        return self.sentiments[index]



    def get_text(self, index):
        if self.load and not self.text_embeddings is None:
            print('loading txt')
            return self.text_embeddings[index]
        text = self.texts[index]
        text = self.text_embedding_extractor.extract_embedding([text])[0]
        return text

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        try:
            face_embedding, pose_embedding = self.get_image_embeddings(index)
        except Exception as e:
            print(f'error for split:{self.split} index: {index}')
            print(e)
            face_embedding = torch.ones(FACE_EMBEDDING_SIZE).to(device)*-123
            pose_embedding = torch.ones(POSE_EMBEDDING_SIZE).to(device)*-123

        sentiment = self.get_sentiment(index)
        text = self.get_text(index)
        sample = {"real_index": self.indexes[index], "pose_embedding": pose_embedding, "face_embedding": face_embedding, "text_embedding": text, "sentiment": sentiment}
        if self.raw:
            sample["text"] = self.texts[index]
            # sample["image"] = pass 

        return sample


# Save features

In [230]:
# %mkdir data/saved_features/
# %mkdir backups/

In [231]:
from torch.utils.data import DataLoader

SAVE_SPLIT = "test"
SAVE_BATCH = 8
dataset = MSCTDDataSet(base_path=project_path / "data/", split = SAVE_SPLIT, load=False)
print(len(dataset))
dataloader = DataLoader(dataset, batch_size=SAVE_BATCH)

3478


In [232]:
def save_features(dataloader, split):
    save_path = project_path / 'data' / 'saved_features'
    stop_batch = None

    for batch_index, batch in enumerate(tqdm(dataloader)):
        # print(batch["face_embedding"].shape)
        # print(batch["text_embedding"].shape)
        # print(batch["real_index"].shape)
        errors = (batch["pose_embedding"]==-123).all(dim=1)

        torch.save(batch["face_embedding"][~errors], save_path / f'face_embeddings_{split}_{batch_index}.pt')
        torch.save(batch["pose_embedding"][~errors], save_path / f'pose_embeddings_{split}_{batch_index}.pt')
        torch.save(batch["text_embedding"][~errors], save_path / f'text_embeddings_{split}_{batch_index}.pt')
        torch.save(batch["real_index"][~errors], save_path / f'real_indexes_{split}_{batch_index}.pt')
        if stop_batch and batch_index==stop_batch:
          break


    print('----------------------')
    print(len(dataloader))
    len_batch = len(dataloader)
    if stop_batch:
        len_batch = stop_batch
    face_embeddings = []
    for i in range(len_batch):
        face_embeddings.append(torch.load(save_path / f'face_embeddings_{split}_{i}.pt'))
    face_embeddings = torch.cat(face_embeddings, dim=0)
    print(face_embeddings.shape)
    torch.save(face_embeddings, save_path / f'face_embeddings_{split}.pt')
    del face_embeddings

    pose_embeddings = []
    for i in range(len_batch):
        pose_embeddings.append(torch.load(save_path / f'pose_embeddings_{split}_{i}.pt'))
    pose_embeddings = torch.cat(pose_embeddings, dim=0)
    print(pose_embeddings.shape)
    torch.save(pose_embeddings, save_path / f'pose_embeddings_{split}.pt')
    del pose_embeddings

    text_embeddings = []
    for i in range(len_batch):
        text_embeddings.append(torch.load(save_path / f'text_embeddings_{split}_{i}.pt'))
    text_embeddings = torch.cat(text_embeddings, dim=0)
    print(text_embeddings.shape)
    torch.save(text_embeddings, save_path / f'text_embeddings_{split}.pt')
    del text_embeddings

    real_indexes = []
    for i in range(len_batch):
        real_indexes.append(torch.load(save_path / f'real_indexes_{split}_{i}.pt'))
    real_indexes = torch.cat(real_indexes, dim=0)
    print(real_indexes.shape)
    torch.save(real_indexes, save_path / f'real_indexes_{split}.pt')
    del real_indexes



In [None]:
save_features(dataloader, SAVE_SPLIT)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  0%|          | 1/435 [00:01<13:30,  1.87s/it]

error for split:test index: 9
'NoneType' object is not subscriptable


  0%|          | 2/435 [00:03<12:24,  1.72s/it]

error for split:test index: 15
('Cannot warp empty image with dimensions', (0, 317, 3))


  1%|          | 3/435 [00:05<12:16,  1.70s/it]

error for split:test index: 22
('Cannot warp empty image with dimensions', (0, 250, 3))
error for split:test index: 23
('Cannot warp empty image with dimensions', (0, 250, 3))


  3%|▎         | 13/435 [00:23<12:46,  1.82s/it]

error for split:test index: 105
('Cannot warp empty image with dimensions', (0, 259, 3))


  4%|▎         | 16/435 [00:28<12:31,  1.79s/it]

error for split:test index: 132
'NoneType' object is not subscriptable


  5%|▍         | 20/435 [00:36<14:31,  2.10s/it]

error for split:test index: 163
('Cannot warp empty image with dimensions', (0, 190, 3))


  6%|▌         | 24/435 [00:46<16:00,  2.34s/it]

error for split:test index: 196
'NoneType' object is not subscriptable


  6%|▌         | 25/435 [00:48<14:53,  2.18s/it]

error for split:test index: 200
'NoneType' object is not subscriptable
error for split:test index: 206
'NoneType' object is not subscriptable


  9%|▊         | 37/435 [01:16<16:01,  2.42s/it]

error for split:test index: 296
stack expects a non-empty TensorList


  9%|▉         | 39/435 [01:21<16:34,  2.51s/it]

error for split:test index: 317
'NoneType' object is not subscriptable


  9%|▉         | 40/435 [01:23<16:15,  2.47s/it]

error for split:test index: 320
stack expects a non-empty TensorList


 11%|█         | 48/435 [01:43<16:06,  2.50s/it]

error for split:test index: 383
('Cannot warp empty image with dimensions', (112, 0, 3))
error for split:test index: 388
('Cannot warp empty image with dimensions', (203, 0, 3))


 11%|█▏        | 49/435 [01:45<14:36,  2.27s/it]

error for split:test index: 396
('Cannot warp empty image with dimensions', (0, 196, 3))


 12%|█▏        | 52/435 [01:51<12:56,  2.03s/it]

error for split:test index: 421
('Cannot warp empty image with dimensions', (0, 349, 3))


 12%|█▏        | 53/435 [01:52<12:19,  1.94s/it]

error for split:test index: 425
('Cannot warp empty image with dimensions', (0, 310, 3))
error for split:test index: 430
('Cannot warp empty image with dimensions', (58, 0, 3))


 13%|█▎        | 55/435 [01:56<12:37,  1.99s/it]

error for split:test index: 438
('Cannot warp empty image with dimensions', (0, 286, 3))
error for split:test index: 439
'NoneType' object is not subscriptable
error for split:test index: 441
stack expects a non-empty TensorList


 13%|█▎        | 57/435 [02:01<13:24,  2.13s/it]

error for split:test index: 462
('Cannot warp empty image with dimensions', (0, 366, 3))


 14%|█▎        | 59/435 [02:05<13:44,  2.19s/it]

error for split:test index: 471
'NoneType' object is not subscriptable
error for split:test index: 474
'NoneType' object is not subscriptable
error for split:test index: 475
'NoneType' object is not subscriptable
error for split:test index: 476
'NoneType' object is not subscriptable


 14%|█▍        | 60/435 [02:07<11:57,  1.91s/it]

error for split:test index: 479
'NoneType' object is not subscriptable


 16%|█▌        | 69/435 [02:26<12:17,  2.02s/it]

error for split:test index: 555
('Cannot warp empty image with dimensions', (0, 462, 3))
error for split:test index: 558
'NoneType' object is not subscriptable


 17%|█▋        | 74/435 [02:38<14:57,  2.49s/it]

error for split:test index: 594
'NoneType' object is not subscriptable
error for split:test index: 597
stack expects a non-empty TensorList


 17%|█▋        | 76/435 [02:43<14:37,  2.44s/it]

error for split:test index: 612
stack expects a non-empty TensorList


 18%|█▊        | 80/435 [02:54<15:42,  2.66s/it]

error for split:test index: 644
'NoneType' object is not subscriptable


 23%|██▎       | 98/435 [03:33<12:23,  2.20s/it]

error for split:test index: 789
stack expects a non-empty TensorList


 23%|██▎       | 100/435 [03:40<15:45,  2.82s/it]

error for split:test index: 801
'NoneType' object is not subscriptable


 23%|██▎       | 101/435 [03:42<15:22,  2.76s/it]

error for split:test index: 813
'NoneType' object is not subscriptable


 23%|██▎       | 102/435 [03:45<15:26,  2.78s/it]

error for split:test index: 816
('Cannot warp empty image with dimensions', (114, 0, 3))
error for split:test index: 817
('Cannot warp empty image with dimensions', (115, 0, 3))
error for split:test index: 818
('Cannot warp empty image with dimensions', (129, 0, 3))


 24%|██▎       | 103/435 [03:48<14:59,  2.71s/it]

error for split:test index: 828
'NoneType' object is not subscriptable


 24%|██▍       | 104/435 [03:50<14:31,  2.63s/it]

error for split:test index: 838
'NoneType' object is not subscriptable


 25%|██▍       | 107/435 [03:56<12:14,  2.24s/it]

error for split:test index: 859
'NoneType' object is not subscriptable


 26%|██▌       | 111/435 [04:05<12:18,  2.28s/it]

error for split:test index: 891
'NoneType' object is not subscriptable


 26%|██▋       | 115/435 [04:14<11:42,  2.19s/it]

error for split:test index: 920
'NoneType' object is not subscriptable
error for split:test index: 923
'NoneType' object is not subscriptable


 28%|██▊       | 120/435 [04:24<11:13,  2.14s/it]

error for split:test index: 964
'NoneType' object is not subscriptable


 28%|██▊       | 121/435 [04:26<10:47,  2.06s/it]

error for split:test index: 974
('Cannot warp empty image with dimensions', (0, 295, 3))


 29%|██▊       | 125/435 [04:33<09:46,  1.89s/it]

error for split:test index: 1003
'NoneType' object is not subscriptable


 29%|██▉       | 126/435 [04:35<09:44,  1.89s/it]

error for split:test index: 1010
('Cannot warp empty image with dimensions', (0, 166, 3))
error for split:test index: 1011
('Cannot warp empty image with dimensions', (0, 0, 3))


 30%|██▉       | 129/435 [04:42<10:58,  2.15s/it]

error for split:test index: 1035
('Cannot warp empty image with dimensions', (0, 160, 3))
error for split:test index: 1036
'NoneType' object is not subscriptable


 30%|███       | 131/435 [04:46<10:27,  2.06s/it]

error for split:test index: 1047
'NoneType' object is not subscriptable


 30%|███       | 132/435 [04:48<10:22,  2.06s/it]

error for split:test index: 1056
'NoneType' object is not subscriptable


 31%|███       | 134/435 [04:52<09:51,  1.97s/it]

error for split:test index: 1070
'NoneType' object is not subscriptable
error for split:test index: 1071
('Cannot warp empty image with dimensions', (423, 0, 3))


 31%|███       | 135/435 [04:54<09:47,  1.96s/it]

error for split:test index: 1083
('Cannot warp empty image with dimensions', (0, 355, 3))
error for split:test index: 1084
('Cannot warp empty image with dimensions', (0, 331, 3))


 31%|███▏      | 136/435 [04:56<09:47,  1.97s/it]

error for split:test index: 1094
('Cannot warp empty image with dimensions', (114, 0, 3))


 31%|███▏      | 137/435 [04:58<09:53,  1.99s/it]

error for split:test index: 1097
('Cannot warp empty image with dimensions', (133, 0, 3))


 32%|███▏      | 138/435 [05:00<11:08,  2.25s/it]

error for split:test index: 1108
('Cannot warp empty image with dimensions', (149, 0, 3))
error for split:test index: 1109
'NoneType' object is not subscriptable


 33%|███▎      | 143/435 [05:11<10:14,  2.10s/it]

error for split:test index: 1145
'NoneType' object is not subscriptable


 33%|███▎      | 145/435 [05:15<09:33,  1.98s/it]

error for split:test index: 1159
('Cannot warp empty image with dimensions', (0, 185, 3))


 34%|███▍      | 150/435 [05:26<10:22,  2.18s/it]

error for split:test index: 1202
('Cannot warp empty image with dimensions', (0, 232, 3))


 36%|███▌      | 155/435 [05:35<09:14,  1.98s/it]

error for split:test index: 1240
('Cannot warp empty image with dimensions', (0, 288, 3))


 37%|███▋      | 159/435 [05:42<08:39,  1.88s/it]

error for split:test index: 1274
'NoneType' object is not subscriptable
error for split:test index: 1278
'NoneType' object is not subscriptable


 37%|███▋      | 161/435 [05:47<09:08,  2.00s/it]

error for split:test index: 1293
'NoneType' object is not subscriptable


 37%|███▋      | 162/435 [05:49<09:38,  2.12s/it]

error for split:test index: 1300
'NoneType' object is not subscriptable
error for split:test index: 1301
'NoneType' object is not subscriptable
error for split:test index: 1302
'NoneType' object is not subscriptable


 37%|███▋      | 163/435 [05:51<08:58,  1.98s/it]

error for split:test index: 1304
'NoneType' object is not subscriptable


 39%|███▉      | 169/435 [06:03<09:22,  2.11s/it]

error for split:test index: 1353
'NoneType' object is not subscriptable


 39%|███▉      | 171/435 [06:08<10:33,  2.40s/it]

error for split:test index: 1371
'NoneType' object is not subscriptable


 40%|███▉      | 172/435 [06:11<10:33,  2.41s/it]

error for split:test index: 1378
'NoneType' object is not subscriptable


 40%|████      | 176/435 [06:21<10:16,  2.38s/it]

error for split:test index: 1408
('Cannot warp empty image with dimensions', (270, 0, 3))


 41%|████▏     | 180/435 [06:29<09:12,  2.17s/it]

error for split:test index: 1446
('Cannot warp empty image with dimensions', (276, 0, 3))


 44%|████▍     | 192/435 [06:53<08:04,  1.99s/it]

error for split:test index: 1538
('Cannot warp empty image with dimensions', (0, 295, 3))


 45%|████▍     | 194/435 [06:57<07:47,  1.94s/it]

error for split:test index: 1552
('Cannot warp empty image with dimensions', (0, 434, 3))
error for split:test index: 1556
('Cannot warp empty image with dimensions', (0, 0, 3))
error for split:test index: 1557
'NoneType' object is not subscriptable


 45%|████▍     | 195/435 [06:58<07:13,  1.81s/it]

error for split:test index: 1561
('Cannot warp empty image with dimensions', (0, 375, 3))
error for split:test index: 1563
('Cannot warp empty image with dimensions', (0, 394, 3))
error for split:test index: 1566
'NoneType' object is not subscriptable


 45%|████▌     | 196/435 [07:00<06:48,  1.71s/it]

error for split:test index: 1569
('Cannot warp empty image with dimensions', (0, 322, 3))
error for split:test index: 1570
('Cannot warp empty image with dimensions', (0, 347, 3))
error for split:test index: 1572
('Cannot warp empty image with dimensions', (0, 341, 3))
error for split:test index: 1574
'NoneType' object is not subscriptable


 45%|████▌     | 197/435 [07:01<06:18,  1.59s/it]

error for split:test index: 1577
('Cannot warp empty image with dimensions', (0, 426, 3))
error for split:test index: 1580
('Cannot warp empty image with dimensions', (0, 425, 3))


 46%|████▌     | 198/435 [07:03<06:24,  1.62s/it]

error for split:test index: 1584
('Cannot warp empty image with dimensions', (0, 387, 3))
error for split:test index: 1586
('Cannot warp empty image with dimensions', (0, 370, 3))
error for split:test index: 1588
('Cannot warp empty image with dimensions', (0, 424, 3))
error for split:test index: 1590
('Cannot warp empty image with dimensions', (0, 403, 3))


 46%|████▌     | 199/435 [07:04<06:06,  1.55s/it]

error for split:test index: 1593
'NoneType' object is not subscriptable


 46%|████▌     | 200/435 [07:06<06:13,  1.59s/it]

error for split:test index: 1599
('Cannot warp empty image with dimensions', (0, 482, 3))
error for split:test index: 1603
'NoneType' object is not subscriptable


 48%|████▊     | 208/435 [07:22<07:06,  1.88s/it]

error for split:test index: 1666
'NoneType' object is not subscriptable


 49%|████▊     | 211/435 [07:27<07:02,  1.89s/it]

error for split:test index: 1689
'NoneType' object is not subscriptable


 49%|████▊     | 212/435 [07:29<06:51,  1.85s/it]

error for split:test index: 1699
'NoneType' object is not subscriptable


 50%|█████     | 218/435 [07:41<07:11,  1.99s/it]

error for split:test index: 1744
('Cannot warp empty image with dimensions', (87, 0, 3))


 51%|█████▏    | 224/435 [07:54<07:51,  2.23s/it]

error for split:test index: 1796
'NoneType' object is not subscriptable


 52%|█████▏    | 225/435 [07:56<08:01,  2.29s/it]

error for split:test index: 1802
'NoneType' object is not subscriptable


 52%|█████▏    | 226/435 [07:59<08:05,  2.32s/it]

error for split:test index: 1812
('Cannot warp empty image with dimensions', (84, 0, 3))
error for split:test index: 1813
('Cannot warp empty image with dimensions', (0, 258, 3))


 52%|█████▏    | 228/435 [08:02<07:13,  2.09s/it]

error for split:test index: 1826
('Cannot warp empty image with dimensions', (0, 293, 3))


 53%|█████▎    | 230/435 [08:06<06:46,  1.98s/it]

error for split:test index: 1840
('Cannot warp empty image with dimensions', (0, 0, 3))


 53%|█████▎    | 231/435 [08:08<06:32,  1.92s/it]

error for split:test index: 1852
'NoneType' object is not subscriptable
error for split:test index: 1853
'NoneType' object is not subscriptable


 53%|█████▎    | 232/435 [08:09<05:58,  1.76s/it]

error for split:test index: 1855
'NoneType' object is not subscriptable
error for split:test index: 1856
'NoneType' object is not subscriptable


 54%|█████▎    | 233/435 [08:11<05:54,  1.75s/it]

error for split:test index: 1864
('Cannot warp empty image with dimensions', (0, 174, 3))


 55%|█████▍    | 238/435 [08:21<06:49,  2.08s/it]

error for split:test index: 1904
'NoneType' object is not subscriptable
error for split:test index: 1907
stack expects a non-empty TensorList


 55%|█████▍    | 239/435 [08:23<06:13,  1.90s/it]

error for split:test index: 1913
'NoneType' object is not subscriptable


 56%|█████▌    | 242/435 [08:28<05:41,  1.77s/it]

error for split:test index: 1940
('Cannot warp empty image with dimensions', (0, 134, 3))


 56%|█████▌    | 243/435 [08:29<05:26,  1.70s/it]

error for split:test index: 1946
('Cannot warp empty image with dimensions', (0, 172, 3))
error for split:test index: 1947
('Cannot warp empty image with dimensions', (0, 169, 3))


 56%|█████▌    | 244/435 [08:31<05:07,  1.61s/it]

error for split:test index: 1956
('Cannot warp empty image with dimensions', (0, 214, 3))


 57%|█████▋    | 250/435 [08:42<05:59,  1.94s/it]

error for split:test index: 2000
('Cannot warp empty image with dimensions', (225, 0, 3))


 58%|█████▊    | 252/435 [08:46<05:58,  1.96s/it]

error for split:test index: 2016
'NoneType' object is not subscriptable


 59%|█████▉    | 257/435 [08:56<05:56,  2.00s/it]

error for split:test index: 2059
'NoneType' object is not subscriptable


 60%|█████▉    | 259/435 [09:00<05:45,  1.96s/it]

error for split:test index: 2071
('Cannot warp empty image with dimensions', (34, 0, 3))
error for split:test index: 2072
('Cannot warp empty image with dimensions', (33, 0, 3))


 60%|█████▉    | 260/435 [09:02<05:40,  1.95s/it]

error for split:test index: 2080
('Cannot warp empty image with dimensions', (0, 150, 3))


 62%|██████▏   | 270/435 [09:21<05:06,  1.85s/it]

error for split:test index: 2163
('Cannot warp empty image with dimensions', (64, 0, 3))


 64%|██████▍   | 279/435 [09:38<04:57,  1.91s/it]

error for split:test index: 2235
'NoneType' object is not subscriptable
error for split:test index: 2236
'NoneType' object is not subscriptable


 64%|██████▍   | 280/435 [09:40<04:49,  1.87s/it]

error for split:test index: 2240
('Cannot warp empty image with dimensions', (0, 62, 3))
error for split:test index: 2243
('Cannot warp empty image with dimensions', (0, 57, 3))


 65%|██████▍   | 281/435 [09:41<04:42,  1.84s/it]

error for split:test index: 2252
('Cannot warp empty image with dimensions', (247, 0, 3))


 65%|██████▍   | 282/435 [09:43<04:40,  1.83s/it]

error for split:test index: 2258
'NoneType' object is not subscriptable


 65%|██████▌   | 284/435 [09:48<05:01,  2.00s/it]

error for split:test index: 2273
'NoneType' object is not subscriptable
error for split:test index: 2277
'NoneType' object is not subscriptable


 66%|██████▌   | 287/435 [09:54<05:01,  2.04s/it]

error for split:test index: 2300
'NoneType' object is not subscriptable
error for split:test index: 2301
'NoneType' object is not subscriptable


 66%|██████▌   | 288/435 [09:55<04:55,  2.01s/it]

error for split:test index: 2308
'NoneType' object is not subscriptable


 67%|██████▋   | 292/435 [10:04<05:11,  2.18s/it]

error for split:test index: 2335
('Cannot warp empty image with dimensions', (0, 28, 3))
error for split:test index: 2339
('Cannot warp empty image with dimensions', (0, 286, 3))


 67%|██████▋   | 293/435 [10:06<04:49,  2.04s/it]

error for split:test index: 2342
('Cannot warp empty image with dimensions', (0, 406, 3))
error for split:test index: 2343
('Cannot warp empty image with dimensions', (0, 431, 3))
error for split:test index: 2346
'NoneType' object is not subscriptable
error for split:test index: 2350
'NoneType' object is not subscriptable


 68%|██████▊   | 294/435 [10:08<04:38,  1.98s/it]

error for split:test index: 2353
'NoneType' object is not subscriptable


 68%|██████▊   | 295/435 [10:10<04:53,  2.10s/it]

error for split:test index: 2360
'NoneType' object is not subscriptable


 70%|███████   | 305/435 [10:31<04:28,  2.07s/it]

error for split:test index: 2442
'NoneType' object is not subscriptable


 71%|███████   | 308/435 [10:37<04:17,  2.03s/it]

error for split:test index: 2469
'NoneType' object is not subscriptable


 71%|███████   | 309/435 [10:39<04:00,  1.91s/it]

error for split:test index: 2471
'NoneType' object is not subscriptable


 73%|███████▎  | 316/435 [10:53<04:03,  2.05s/it]

error for split:test index: 2531
('Cannot warp empty image with dimensions', (245, 0, 3))


 73%|███████▎  | 317/435 [10:55<03:48,  1.94s/it]

error for split:test index: 2539
('Cannot warp empty image with dimensions', (0, 364, 3))


 73%|███████▎  | 318/435 [10:56<03:29,  1.79s/it]

error for split:test index: 2542
'NoneType' object is not subscriptable
error for split:test index: 2543
'NoneType' object is not subscriptable


 73%|███████▎  | 319/435 [10:58<03:28,  1.80s/it]

error for split:test index: 2554
'NoneType' object is not subscriptable


 74%|███████▍  | 322/435 [11:03<03:20,  1.78s/it]

error for split:test index: 2574
('Cannot warp empty image with dimensions', (0, 298, 3))
error for split:test index: 2575
('Cannot warp empty image with dimensions', (0, 367, 3))
error for split:test index: 2577
('Cannot warp empty image with dimensions', (0, 413, 3))
error for split:test index: 2578
'NoneType' object is not subscriptable
error for split:test index: 2579
'NoneType' object is not subscriptable
error for split:test index: 2580
'NoneType' object is not subscriptable
error for split:test index: 2581
('Cannot warp empty image with dimensions', (0, 474, 3))
error for split:test index: 2582
('Cannot warp empty image with dimensions', (0, 402, 3))


 74%|███████▍  | 323/435 [11:04<02:45,  1.47s/it]

error for split:test index: 2583
'NoneType' object is not subscriptable
error for split:test index: 2586
('Cannot warp empty image with dimensions', (0, 299, 3))
error for split:test index: 2587
('Cannot warp empty image with dimensions', (0, 290, 3))
error for split:test index: 2588
'NoneType' object is not subscriptable


 74%|███████▍  | 324/435 [11:06<02:47,  1.51s/it]

error for split:test index: 2595
'NoneType' object is not subscriptable


 75%|███████▌  | 327/435 [11:12<03:29,  1.94s/it]

error for split:test index: 2618
'NoneType' object is not subscriptable
error for split:test index: 2621
'NoneType' object is not subscriptable


 75%|███████▌  | 328/435 [11:14<03:26,  1.93s/it]

error for split:test index: 2629
'NoneType' object is not subscriptable


 76%|███████▌  | 330/435 [11:18<03:30,  2.00s/it]

error for split:test index: 2642
('Cannot warp empty image with dimensions', (0, 0, 3))


 76%|███████▌  | 331/435 [11:20<03:23,  1.95s/it]

error for split:test index: 2647
('Cannot warp empty image with dimensions', (0, 112, 3))
error for split:test index: 2652
stack expects a non-empty TensorList
error for split:test index: 2653
('Cannot warp empty image with dimensions', (110, 0, 3))


 78%|███████▊  | 339/435 [11:36<03:18,  2.07s/it]

error for split:test index: 2712
'NoneType' object is not subscriptable


 80%|███████▉  | 346/435 [11:51<03:09,  2.13s/it]

error for split:test index: 2769
'NoneType' object is not subscriptable


 80%|████████  | 349/435 [11:57<02:51,  2.00s/it]

error for split:test index: 2798
stack expects a non-empty TensorList


 81%|████████  | 351/435 [12:01<02:51,  2.05s/it]

error for split:test index: 2809
'NoneType' object is not subscriptable
error for split:test index: 2813
('Cannot warp empty image with dimensions', (0, 194, 3))


 81%|████████  | 352/435 [12:03<02:42,  1.96s/it]

error for split:test index: 2816
stack expects a non-empty TensorList
error for split:test index: 2817
stack expects a non-empty TensorList


 82%|████████▏ | 355/435 [12:08<02:35,  1.94s/it]

error for split:test index: 2844
'NoneType' object is not subscriptable


 83%|████████▎ | 360/435 [12:18<02:26,  1.96s/it]

error for split:test index: 2885
stack expects a non-empty TensorList


 84%|████████▎ | 364/435 [12:25<02:11,  1.85s/it]

error for split:test index: 2911
'NoneType' object is not subscriptable


 84%|████████▍ | 365/435 [12:28<02:13,  1.90s/it]

error for split:test index: 2926
'NoneType' object is not subscriptable


 84%|████████▍ | 366/435 [12:29<02:11,  1.90s/it]

error for split:test index: 2934
'NoneType' object is not subscriptable


 85%|████████▍ | 369/435 [12:36<02:17,  2.08s/it]

In [None]:
# CHANGE VAL TO SPLIT
!cp data/saved_features/face_embeddings_val.pt backup
!cp data/saved_features/pose_embeddings_val.pt backup
!cp data/saved_features/real_indexes_val.pt backup
!cp data/saved_features/text_embeddings_val.pt backup

In [45]:
!ls -sh backup

ls: cannot access 'backup': No such file or directory


In [None]:
%ls data/saved_features

In [None]:
!du data/saved_features/face_embeddings_val.pt -h

17M	data/saved_features/face_embeddings_val.pt


In [None]:
# save_path = project_path / 'data' / 'saved_features'
# split = "val"
# face_embeddings = []
# for i in range(3):
#     face_embeddings.append(torch.load(save_path / f'face_embeddings_{split}_{i}.pt'))
# face_embeddings = torch.cat(face_embeddings, dim=0)
# print(face_embeddings.shape)
# torch.save(face_embeddings, save_path / f'face_embeddings_{split}_{batch_index}.pt')
# del face_embeddings

In [None]:
del dataset
del dataloader

#Data Loader

In [None]:
class MSCTDDataLoader:
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        return len(self.dl)

def to_device(data, device):
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    if isinstance(data, str):
        return data
    return data.to(device)

# ds = MSCTDDataSet(base_path=project_path + "data/", dataset_type = "val", load=True)
# dl = DataLoader(ds, batch_size=10)
# dl = MSCTDDataLoader(dl, device)
# for x in dl:
#   print(x)
#   print(x['face_embedding'].shape)
#   print(x['text_embedding'].shape)
#   print(x['real_index'])
#   break

In [None]:
import torch
from torch import nn

class SimpleDenseNetwork(nn.Module):
    def __init__(self, n_classes, embedding_dimension):
        super(SimpleDenseNetwork, self).__init__()

        self.n_classes = n_classes
        self.embedding_dimension = embedding_dimension

        self.fc = nn.Sequential(
            nn.Linear(
                in_features=self.embedding_dimension,
                out_features=512,
            ),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=3),
            nn.ReLU(inplace=True),
            nn.Softmax(),
        )

    def forward(self, input_batch):
        x = input_batch
        x = self.fc(x)
        output_batch = x

        return output_batch

#Train

In [None]:
BATCH_SIZE = 32
num_workers = 1
EPOCHS = 1
# embedding_dimension = 2048 + 34
embedding_dimension = FACE_EMBEDDING_SIZE + TEXT_EMBEDDING_SIZE + POSE_EMBEDDING_SIZE # + SCENE_EMBEDDING_SIZE

learning_rate = 0.01
momentum = 0.9
data_size = 1000

In [None]:
import torch.optim as optim
from datetime import datetime


def train_epoch(epoch_index, model, dataloader, loss_fn, optimizer):
    running_loss = 0.0
    # last_loss = 0.0

    for data_pair_index, batch in enumerate(dataloader):
        print("--------------", data_pair_index, "-------------")
        errors = (batch["pose_embeding"]==-123).all(dim=1)
        text_embedding = batch["text_embedding"][~errors]
        face_embedding = batch["face_embedding"][[~errors]]
        pose_embeding = batch["pose_embeding"][[~errors]]

        labels = batch["sentiment"]
        optimizer.zero_grad()

        outputs = model(torch.cat((face_embedding, text_embedding, pose_embeding), 1))

        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        # if data_pair_index % 1000 == 999:
        #     last_loss = running_loss / 1000  # loss per batch
        #     print("  batch {} loss: {}".format(data_pair_index + 1, last_loss))
        #     tb_x = epoch_index * len(dataloader) + data_pair_index + 1
        #     print("Loss/train", last_loss, tb_x)
        #     running_loss = 0.0
    print('Batch loss: ', running_loss)
    # return last_loss


def train_model(model, epochs, train_dataloader, val_dataloader):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    for epoch in trange(epochs):
        model.train()
        train_epoch(epoch, model, train_dataloader, loss_fn, optimizer)
        model.eval()
        validate(model, val_dataloader, loss_fn)

    return model

In [None]:
model = SimpleDenseNetwork(n_classes=3, embedding_dimension=embedding_dimension).to(device=device)

In [None]:
val_dataset = MSCTDDataSet(project_path + "data/", "val", data_size=data_size, load=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
# val_dataloader = MSCTDDataLoader(val_dataloader, device)

test_dataset = MSCTDDataSet(project_path + "data/", "test", load=True)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
# test_dataloader = MSCTDDataLoader(test_dataloader, device)

In [None]:
model = train_model(model, EPOCHS, val_dataloader, test_dataloader)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


-------------- 0 -------------


  input = module(input)


-------------- 1 -------------
-------------- 2 -------------
-------------- 3 -------------
-------------- 4 -------------


#Evaluating

In [None]:
import evaluate

accuracy = evaluate.load("accuracy")
precision = evaluate.load("precision")

def validate(model, dataloader, loss_fn):
    running_loss = 0.0
    last_loss = 0.0

    for data_pair_index, batch in enumerate(dataloader):
        print("--------------", data_pair_index, "-------------")
        text_embedding = batch["text_embedding"]
        face_embedding = batch["face"]
        # pose_embeding = batch["pose_embeding"]
        labels = batch["sentiment"]

        logits = model(torch.cat((text_embedding, face_embedding), 1))
        # print(outputs)
        accuracy.add_batch(predictions=logits.argmax(dim=1), references=labels)
        precision.add_batch(predictions=logits.argmax(dim=1), references=labels)
        loss = loss_fn(logits, labels)
        running_loss += loss.item()
        # print(running_loss)
        # print('true answer',labels)
        # print('prediction',logits.argmax(dim=1))
        # if data_pair_index==2:
        #   break
    print(accuracy.compute())
    print(precision.compute(average=None))

In [None]:
validate(model, test_dataloader, nn.CrossEntropyLoss())