In [None]:
!pip install scipy scikit-image torch tqdm transformers mtcnn mediapipe opencv-python torchvision numpy pandas timm evaluate gluoncv mxnet 

In [2]:
from google.colab import drive
from pathlib import Path
drive.mount("/content/drive")
project_path = Path("/content/drive/MyDrive/NLP/MultiModalEmotionRecognition/")

Mounted at /content/drive


In [3]:
%cd /content/drive/MyDrive/NLP/MultiModalEmotionRecognition/data

/content/drive/MyDrive/NLP/MultiModalEmotionRecognition/data


In [4]:
!ls

correct_indexes   error_indexes		saved_features	    train_ende.zip
dev.zip		  image_index_test.txt	sentiment_test.txt
english_test.txt  image_index_val.txt	sentiment_val.txt
english_val.txt   images		test.zip


In [None]:
!unzip dev.zip
!unzip test.zip
!unzip train_ende.zip

In [None]:
%mv dev/ images/val
%mv test/ images/test
%mv train_ende/ images/train

In [None]:
%ls images/test -1 | wc -l

5067


In [6]:
%cd ..

/content/drive/MyDrive/NLP/MultiModalEmotionRecognition


In [8]:
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [9]:
import torch
import numpy as np
from gluoncv import model_zoo, data
from gluoncv.data.transforms.pose import detector_to_simple_pose, heatmap_to_coord

class PoseEmbeddingExtractor:
    def __init__(
        self,
    ):
        self.detector = model_zoo.get_model('yolo3_mobilenet1.0_coco', pretrained=True)
        self.pose_net = model_zoo.get_model('simple_pose_resnet18_v1b', pretrained=True)
        self.detector.reset_class(["person"], reuse_weights=['person'])

    def detect_person(self, x, image):
        class_IDs, scores, bounding_boxs = self.detector(x)
        pose_input, upscale_bbox = detector_to_simple_pose(image, class_IDs, scores, bounding_boxs)
        return pose_input, upscale_bbox

    def get_most_confident_coords(self, predicted_heatmap, upscale_bbox):
        pred_coords, confidence = heatmap_to_coord(predicted_heatmap, upscale_bbox)
        mean_confidence = np.mean(confidence[:,:,0].asnumpy(), axis=1)
        best_confidence_arg = mean_confidence.argmax()
        best_coords = pred_coords[best_confidence_arg].asnumpy().ravel()
        return best_coords

    def extract_embedding(self, image_path):
        x, image = data.transforms.presets.ssd.load_test(image_path, short=512)
        pose_input, upscale_bbox = self.detect_person(x, image)
        predicted_heatmap = self.pose_net(pose_input)
        best_coords = self.get_most_confident_coords(predicted_heatmap, upscale_bbox)
        return torch.tensor(best_coords)




#Face Embedding

In [28]:
from scipy.spatial.distance import euclidean
import math
from skimage.transform import rotate
from mtcnn import MTCNN
import mediapipe
import numpy as np
import pandas as pd
import cv2
import os
from PIL import Image
import torch
from torchvision import transforms
import urllib


def get_model_path(model_name):
    model_file = model_name + ".pt"
    cache_dir = os.path.join(os.path.expanduser("~"), ".hsemotions")
    # cache_dir = "emotion_models"
    os.makedirs(cache_dir, exist_ok=True)
    fpath = os.path.join(cache_dir, model_file)
    if not os.path.isfile(fpath):
        print(f"{model_file} not exists")
        url = (
            "https://github.com/HSE-asavchenko/face-emotion-recognition/blob/main/models/affectnet_emotions/"
            + model_file
            + "?raw=true"
        )
        print("Downloading", model_name, "from", url)
        urllib.request.urlretrieve(url, fpath)

    return fpath


class FaceAlignment:
    def __init__(
        self,
    ):
        pass

    @staticmethod
    def apply_rotation_on_images(input_images, angles):
        rotated_images = [
            rotate(image, angle) for image, angle in zip(input_images, angles)
        ]
        return rotated_images

    @staticmethod
    def compute_alignment_rotation_(eyes_coordinates):
        angles = []
        directions = []
        for left_eye_coordinate, right_eye_coordinate in eyes_coordinates:

            left_eye_x, left_eye_y = left_eye_coordinate
            right_eye_x, right_eye_y = right_eye_coordinate

            triangle_vertex = (
                (right_eye_x, left_eye_y)
                if left_eye_y > right_eye_y
                else (left_eye_x, right_eye_y)
            )
            direction = (
                -1 if left_eye_y > right_eye_y else 1
            )  # rotate clockwise else counter-clockwise

            # compute length of triangle edges
            a = euclidean(left_eye_coordinate, triangle_vertex)
            b = euclidean(right_eye_coordinate, triangle_vertex)
            c = euclidean(right_eye_coordinate, left_eye_coordinate)

            # cosine rule
            if (
                b != 0 and c != 0
            ):  # this multiplication causes division by zero in cos_a calculation
                cos_a = (b**2 + c**2 - a**2) / (2 * b * c)
                angle = np.arccos(cos_a)  # angle in radian
                angle = (angle * 180) / math.pi  # radian to degree
            else:
                angle = 0

            angle = angle - 90 if direction == -1 else angle

            angles.append(angle)
            directions.append(direction)

        return angles, directions


class FaceDetection:

    # first call extract_face
    def __init__(self, model_name, minimum_confidence):

        self.detected_faces_information = None
        self.model_name = model_name
        self.minimum_confidence = minimum_confidence

        if model_name == "MTCNN":
            detector_model = MTCNN()
            self.detect_faces_function = (
                lambda input_image: detector_model.detect_faces(input_image)
            )

    def extract_faces(self, input_image, return_detections_information=True):
        self.detect_faces__(input_image)
        faces = self.get_faces__(
            input_image,
        )
        if return_detections_information:
            return faces, self.detected_faces_information

        else:
            return faces

    def detect_faces__(self, input_image):
        detections = self.detect_faces_function(input_image)
        self.detected_faces_information = list(
            filter(
                lambda element: element["confidence"] > self.minimum_confidence,
                detections,
            )
        )

    def get_detected_faces_information(self):
        return self.detected_faces_information

    def get_keypoints(
        self,
    ):
        return list(
            map(lambda element: element["keypoints"], self.detected_faces_information)
        )

    def get_faces__(
        self,
        input_image,
    ):
        boxes = [
            detection_information["box"]
            for detection_information in self.detected_faces_information
        ]
        y1y2x1x2 = [(int(y), int(y + h), int(x), int(x + w)) for x, y, w, h in boxes]
        faces = [input_image[y1:y2, x1:x2] for y1, y2, x1, x2 in y1y2x1x2]
        return faces

    def get_eyes_coordinates(
        self,
    ):
        eyes_coordinates = [
            (info["keypoints"]["left_eye"], info["keypoints"]["right_eye"])
            for info in self.detected_faces_information
        ]
        return eyes_coordinates


class FaceEmotionRecognizer:
    # supported values of model_name: enet_b0_8_best_vgaf, enet_b0_8_best_afew, enet_b2_8, enet_b0_8_va_mtl, enet_b2_7
    def __init__(self, device, model_name="enet_b0_8_best_vgaf"):
        self.device = device
        self.is_mtl = "_mtl" in model_name
        if "_7" in model_name:
            self.idx_to_class = {
                0: "Anger",
                1: "Disgust",
                2: "Fear",
                3: "Happiness",
                4: "Neutral",
                5: "Sadness",
                6: "Surprise",
            }
        else:
            self.idx_to_class = {
                0: "Anger",
                1: "Contempt",
                2: "Disgust",
                3: "Fear",
                4: "Happiness",
                5: "Neutral",
                6: "Sadness",
                7: "Surprise",
            }

        self.img_size = 224 if "_b0_" in model_name else 260
        self.test_transforms = transforms.Compose(
            [
                transforms.Resize((self.img_size, self.img_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

        path = get_model_path(model_name)

        model = torch.load(path)
        model = model.to(device)

        if isinstance(model.classifier, torch.nn.Sequential):
            self.classifier_weights = model.classifier[0].weight.data
            self.classifier_bias = model.classifier[0].bias.data
        else:
            self.classifier_weights = model.classifier.weight.data
            self.classifier_bias = model.classifier.bias.data

        model.classifier = torch.nn.Identity()
        self.model = model.eval()
        # print(path, self.test_transforms)

    def compute_probability(self, features):
        return torch.matmul(features, self.classifier_weights.T) + self.classifier_bias

    def extract_representations_from_faces(self, input_faces):
        faces = [self.test_transforms(Image.fromarray(face)) for face in input_faces]
        features = self.model(torch.stack(faces, dim=0).to(self.device))
        return features

    def predict_emotions_from_representations(
        self, representations, logits=True, return_features=True
    ):
        scores = self.compute_probability(representations)
        if self.is_mtl:
            predictions_indices = torch.argmax(scores[:, :-2], dim=1)

        else:
            predictions_indices = torch.argmax(scores, dim=1)

        if self.is_mtl:
            x = scores[:, :-2]

        else:
            x = scores
        pred = torch.argmax(x[0])

        if not logits:
            e_x = torch.exp(x - torch.max(x, dim=1)[:, None])
            e_x = e_x / e_x.sum(dim=1)[:, None]
            if self.is_mtl:
                scores[:, :-2] = e_x
            else:
                scores = e_x

        return [
            self.idx_to_class[pred.item()] for pred in (predictions_indices)
        ], scores


class FaceNormalizer:
    def __init__(self):
        self.mp_face_mesh = mediapipe.solutions.face_mesh
        face_mesh = self.mp_face_mesh.FaceMesh(static_image_mode=True)

        mp_face_mesh = mediapipe.solutions.face_mesh
        self.face_mesh = mp_face_mesh.FaceMesh(static_image_mode=True)
        self.routes_idx = self.initialize__()

    def initialize__(self):
        df = pd.DataFrame(
            list(self.mp_face_mesh.FACEMESH_FACE_OVAL), columns=["p1", "p2"]
        )
        routes_idx = []

        p1 = df.iloc[0]["p1"]
        p2 = df.iloc[0]["p2"]

        for i in range(0, df.shape[0]):
            obj = df[df["p1"] == p2]
            p1 = obj["p1"].values[0]
            p2 = obj["p2"].values[0]

            route_idx = []
            route_idx.append(p1)
            route_idx.append(p2)
            routes_idx.append(route_idx)

        return routes_idx

    def get_landmarks__(self, input_image: np.ndarray):
        if input_image.dtype == np.float:
            input_image = (input_image * 255).astype(np.uint8)

        results = self.face_mesh.process(input_image)
        landmarks = results.multi_face_landmarks[0]

        routes = []
        # for source_idx, target_idx in mp_face_mesh.FACEMESH_FACE_OVAL:
        for source_idx, target_idx in self.routes_idx:
            source = landmarks.landmark[source_idx]
            target = landmarks.landmark[target_idx]

            relative_source = (
                int(input_image.shape[1] * source.x),
                int(input_image.shape[0] * source.y),
            )
            relative_target = (
                int(input_image.shape[1] * target.x),
                int(input_image.shape[0] * target.y),
            )

            # cv2.line(img, relative_source, relative_target, (255, 255, 255), thickness = 2)

            routes.append(relative_source)
            routes.append(relative_target)

        return routes

    @staticmethod
    def normalize_with_landmark_points__(input_image, landmarks):
        mask = np.zeros((input_image.shape[0], input_image.shape[1]))
        mask = cv2.fillConvexPoly(mask, np.array(landmarks), 1)
        mask = mask.astype(bool)

        out = np.zeros_like(input_image)
        out[mask] = input_image[mask]
        return out

    def normalize_faces_image(self, input_images):
        normalized_faces_images = [
            self.normalize_with_landmark_points__(
                input_image, self.get_landmarks__(input_image)
            )
            for input_image in input_images
        ]
        return normalized_faces_images


class FaceEmbeddingExtractor:
    def __init__(
        self,
    ):
        self.face_detection_model: FaceDetection = None
        self.face_alignment_model: FaceAlignment = None
        self.face_normalizer_model: FaceNormalizer = None
        self.face_emotion_recognition_model: FaceEmotionRecognizer = None

        self.faces = None
        self.normalized_rotated_faces = None
        self.rotated_faces = None
        self.rotation_angles = None
        self.rotation_directions = None

    def set_face_detection_model(self, face_detection_model):
        self.face_detection_model = face_detection_model
        return self

    def set_face_alignment_model(self, face_alignment_model):
        self.face_alignment_model = face_alignment_model
        return self

    def set_face_normalizer_model(self, face_normalizer_model):
        self.face_normalizer_model = face_normalizer_model
        return self

    def set_face_emotion_recognition_model(self, face_emotion_recognition_model):
        self.face_emotion_recognition_model = face_emotion_recognition_model
        return self

    def extract_embedding(self, input_image):
        faces, detected_faces_information = self.face_detection_model.extract_faces(
            input_image, return_detections_information=True
        )

        (
            rotation_angles,
            rotation_directions,
        ) = self.face_alignment_model.compute_alignment_rotation_(
            self.face_detection_model.get_eyes_coordinates()
        )
        rotated_faces = self.face_alignment_model.apply_rotation_on_images(
            faces, rotation_angles
        )
        normalized_rotated_faces = self.face_normalizer_model.normalize_faces_image(
            rotated_faces
        )

        normalized_rotated_faces_255 = [
            (image * 255).astype(np.uint8) for image in normalized_rotated_faces
        ]

        representations = (
            self.face_emotion_recognition_model.extract_representations_from_faces(
                normalized_rotated_faces_255
            )
        )[0] #WARNING: 0 was not here
        del normalized_rotated_faces_255
        del normalized_rotated_faces
        del rotated_faces
        del rotation_angles
        del rotation_directions
        del faces
        del detected_faces_information
        # (
        #     predictions,
        #     scores,
        # ) = self.face_emotion_recognition_model.predict_emotions_from_representations(
        #     representations
        # )

        # self.faces = faces
        # self.rotation_angles, self.rotation_directions = (
        #     rotation_angles,
        #     rotation_directions,
        # )
        # self.rotated_faces = rotated_faces
        # self.normalized_rotated_faces = normalized_rotated_faces_255

        return None, None, representations
        # return preictions, scores, representations

    def get_rotations_information(self):
        return self.rotation_angles, self.rotation_directions

    def get_faces(self):
        return self.faces

    def get_rotated_faces(self):
        return self.rotated_faces

    def get_normalized_rotated_faces(self):
        return self.normalized_rotated_faces

    def clear(self):
        self.faces = None
        self.normalized_rotated_faces = None
        self.rotated_faces = None
        self.rotation_angles = None
        self.rotation_directions = None

    def store_embeddings(self, file, embeddings):
        with open(file, "wb") as file_out:
            pickle.dump(
                {"embeddings": embeddings}, file_out, protocol=pickle.HIGHEST_PROTOCOL
            )

    def load_embeddings(self, file):
        with open(file, "rb") as file_in:
            stored_data = pickle.load(file_in)
            stored_embeddings = stored_data["embeddings"]

        return stored_embeddings

#Text Embedding


In [29]:
from transformers import AutoTokenizer, AutoModel, pipeline
from transformers import RobertaForSequenceClassification
import torch
import pickle


class TextEmbeddingExtractor:
    def __init__(
        self,
        model_name="pysentimiento/robertuito-sentiment-analysis",
        batch_size=250,
        show_progress_bar=True,
        to_tensor=True,
        max_length=128,
    ):
        self.model_name = model_name

        self.device = device

        self.batch_size = batch_size
        self.show_progress_bar = show_progress_bar
        self.to_tensor = to_tensor

        self.max_length = max_length

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        # self.model = AutoModel.from_pretrained(self.model_name).to(self.device)

        self.model = RobertaForSequenceClassification.from_pretrained(
            self.model_name, num_labels=3, output_hidden_states=True
        ).to(self.device)

        # C1
        self.generator = pipeline(
            task="sentiment-analysis",
            model=self.model,
            tokenizer=self.tokenizer,
        )

    def extract_embedding(
        self,
        input_batch_sentences,
    ):
        encoded_input = self.tokenizer(
            input_batch_sentences,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            model_output = self.model(**encoded_input)
            hidden_states = model_output["hidden_states"]
            last_layer_hidden_states = hidden_states[
                12
            ]  # 12 = len(hidden_states) , dim = (batch_size, seq_len, 768)
            cls_hidden_state = last_layer_hidden_states[:, 0, :]

        return cls_hidden_state

    def get_labels(self, input_batch_sentences):
        return self.generator(input_batch_sentences)

    @staticmethod
    def store_embeddings(file, embeddings):
        with open(file, "wb") as file_out:
            pickle.dump(
                {"embeddings": embeddings}, file_out, protocol=pickle.HIGHEST_PROTOCOL
            )

    @staticmethod
    def load_embeddings(file):
        with open(file, "rb") as file_in:
            stored_data = pickle.load(file_in)
            stored_embeddings = stored_data["embeddings"]

        return stored_embeddings

#Dataset

In [30]:
!free

              total        used        free      shared  buff/cache   available
Mem:       13298580     7093508      164564       21704     6040508     9269192
Swap:             0           0           0


In [35]:
import os, cv2, torch, ast
import pandas as pd
import numpy as np
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import trange
from tqdm import tqdm


class MSCTDDataSet(Dataset):
    """MSCTD dataset."""

    def __init__(self, base_path="data/", dataset_type="train", data_size=None, load=False, raw=False):
        """
        Args:
            base_path (str or path): path to data folder
            dataset_type (str): dev, train, test
        """
        if isinstance(base_path, str):
            base_path = Path(base_path)
        self.base_path = base_path
        self.load_path = base_path / 'saved_features'
        self.dataset_type = dataset_type
        self.text_file_path = base_path / f"english_{dataset_type}.txt"
        self.seq_file_path = base_path / f"image_index_{dataset_type}.txt"
        self.sentiment_file_path = base_path / f"sentiment_{dataset_type}.txt"
        self.image_dir = base_path / "images" / dataset_type
        self.correct_indexes_file_path = base_path / "correct_indexes" / f"correct_indexes_{dataset_type}.txt"

        self.data_size = data_size
        self.load = load
        self.raw = raw

        self.texts = None
        self.sentiments = None
        self.indexes = None
        self.face_embeddings = None
        self.pose_embeddings = None
        self.text_embeddings = None
        self.load_data()
        self.face_embedding_extractor = self.get_face_embedding_extractor()
        self.text_embedding_extractor = TextEmbeddingExtractor()
        self.pose_embedding_extractor = PoseEmbeddingExtractor()


    def get_face_embedding_extractor(self):
        fd = FaceDetection("MTCNN", minimum_confidence=0.95)
        fa = FaceAlignment()
        fn = FaceNormalizer()
        model_name = "enet_b0_8_best_afew"
        fer = FaceEmotionRecognizer(device, model_name)
        fre = (
            FaceEmbeddingExtractor()
            .set_face_detection_model(fd)
            .set_face_alignment_model(fa)
            .set_face_normalizer_model(fn)
            .set_face_emotion_recognition_model(fer)
        )
        return fre

    def load_data(self):
        with open(self.text_file_path) as text_file, open(self.sentiment_file_path) as sentiment_file, open(self.correct_indexes_file_path) as correct_file:
            texts = [t.strip() for t in text_file.readlines()]
            sentiments = [int(t.strip()) for t in sentiment_file.readlines()]
            face_embeddings = None
            pose_embeddings = None
            text_embeddings = None
            corrects = [int(c.strip()) for c in correct_file.readlines()]
            if self.load:
                try:
                    face_embeddings = torch.load(self.save_path / f'face_embeddings_{self.dataset_type}.pt')
                    pose_embeddings = torch.load(self.save_path / f'pose_embeddings_{self.dataset_type}.pt')
                    text_embeddings = torch.load(self.save_path / f'text_embeddings_{self.dataset_type}.pt')
                    corrects = torch.load(self.save_path / f'real_indexes_{self.dataset_type}.pt')
                except Exception as e:
                    print(e)
                    print('Warning: passed load=True but not embedding file was located. Not loading')

            correct_texts = [texts[i] for i in corrects]
            correct_sentiments = [sentiments[i] for i in corrects]
        # with open(self.image_index_path) as f:
        #     images = [ast.literal_eval(t.strip()) for t in f.readlines()]

        if self.data_size:
            correct_texts = correct_texts[: self.data_size]
            correct_sentiments =correct_sentiments[: self.data_size]
            if face_embeddings:
                face_embeddings = face_embeddings[:self.data_size,:]
            if pose_embeddings:
                pose_embeddings = pose_embeddings[:self.data_size,:]
            if text_embeddings:
                face_embeddings = text_embeddings[:self.data_size,:]
            # images = images[: self.data_size]


        self.texts = correct_texts
        self.text_embeddings = text_embeddings
        self.sentiments = correct_sentiments
        self.indexes = corrects
        self.face_embeddings = face_embeddings
        self.pose_embeddings = pose_embeddings



    def __len__(self):
        return len(self.texts)

    def get_face_embedding(self, index, image):
        if self.load:
            print('loading img')
            return self.face_embeddings[index]
        (
            predictions,
            scores,
            representations,
        ) = self.face_embedding_extractor.extract_embedding(image)
        return representations

    def get_pose_embedding(self, index, image):
        if self.load:
            print('loading img')
            return self.pose_embeddings[index]
        return self.pose_embedding_extractor.extract_embedding(image)

    def get_image_embeddings(self, index):
        image = None
        real_index = self.indexes[index]
        image_name = self.image_dir / f"{real_index}.jpg"
        if not self.load:
            image = cv2.imread(str(image_name))[:, :, ::-1]

        face_embedding = self.get_face_embedding(index, image)
        pose_embedding = self.get_pose_embedding(index, str(image_name))
        return face_embedding, pose_embedding

    def get_sentiment(self, index):
        return self.sentiments[index]



    def get_text(self, index):
        if self.load and not self.text_embeddings is None:
            print('loading txt')
            return self.text_embeddings[index]
        text = self.texts[index]
        text = self.text_embedding_extractor.extract_embedding([text])[0]
        return text

    def __getitem__(self, index):
        if torch.is_tensor(index):
            index = index.tolist()
        
        face_embedding, pose_embedding = self.get_image_embeddings(index)
        sentiment = self.get_sentiment(index)
        text = self.get_text(index)
        sample = {"real_index": self.indexes[index], "pose_embedding": pose_embedding, "face_embedding": face_embedding, "text_embedding": text, "sentiment": sentiment}
        if self.raw:
            sample["text"] = self.texts[index]
            # sample["image"] = pass 

        return sample


In [36]:
# %mkdir data/saved_features/

In [37]:
from torch.utils.data import DataLoader

ds = MSCTDDataSet(base_path=project_path / "data/", dataset_type = "val", load=False)
print(len(ds))
dataloader = DataLoader(ds, batch_size=8)

3461


In [None]:
def save_features(dataloader, dataset_type):
    save_path = project_path / 'data' / 'saved_features'
    # stop_batch = 6

    for batch_index, batch in enumerate(tqdm(dataloader)):
        # print(batch["face_embedding"].shape)
        # print(batch["text_embedding"].shape)
        # print(batch["real_index"].shape)
        torch.save(batch["face_embedding"], save_path / f'face_embeddings_{dataset_type}_{batch_index}.pt')
        torch.save(batch["pose_embedding"], save_path / f'pose_embeddings_{dataset_type}_{batch_index}.pt')
        torch.save(batch["text_embedding"], save_path / f'text_embeddings_{dataset_type}_{batch_index}.pt')
        torch.save(batch["real_index"], save_path / f'real_indexes_{dataset_type}_{batch_index}.pt')
        # if batch_index==stop_batch:
        #   break

    print('----------------------')
    print(len(dataloader))
    len_batch = len(dataloader)
    # len_batch = stop_batch
    face_embeddings = []
    for i in range(len_batch):
        face_embeddings.append(torch.load(save_path / f'face_embeddings_{dataset_type}_{i}.pt'))
    face_embeddings = torch.cat(face_embeddings, dim=0)
    print(face_embeddings.shape)
    torch.save(face_embeddings, save_path / f'face_embeddings_{dataset_type}.pt')
    del face_embeddings

    pose_embeddings = []
    for i in range(len_batch):
        pose_embeddings.append(torch.load(save_path / f'pose_embeddings+{dataset_type}_{i}.pt'))
    pose_embeddings = torch.cat(pose_embeddings, dim=0)
    print(pose_embeddings.shape)
    torch.save(pose_embeddings, save_path / f'pose_embeddings_{dataset_type}.pt')
    del pose_embeddings

    text_embeddings = []
    for i in range(len_batch):
        text_embeddings.append(torch.load(save_path / f'text_embeddings_{dataset_type}_{i}.pt'))
    text_embeddings = torch.cat(text_embeddings, dim=0)
    print(text_embeddings.shape)
    torch.save(text_embeddings, save_path / f'text_embeddings_{dataset_type}.pt')
    del text_embeddings

    real_indexes = []
    for i in range(len_batch):
        real_indexes.append(torch.load(save_path / f'real_indexes_{dataset_type}_{i}.pt'))
    real_indexes = torch.cat(real_indexes, dim=0)
    print(real_indexes.shape)
    torch.save(real_indexes, save_path / f'real_indexes_{dataset_type}.pt')
    del real_indexes

save_features(dataloader, "val")

  0%|          | 0/433 [00:00<?, ?it/s]







Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations




  0%|          | 1/433 [00:20<2:30:36, 20.92s/it]



  0%|          | 2/433 [00:40<2:25:04, 20.20s/it]



  1%|          | 3/433 [01:00<2:22:33, 19.89s/it]



  1%|          | 4/433 [01:21<2:25:12, 20.31s/it]



  1%|          | 5/433 [01:39<2:20:14, 19.66s/it]



  1%|▏         | 6/433 [01:58<2:16:59, 19.25s/it]



  2%|▏         | 7/433 [02:16<2:14:06, 18.89s/it]



  2%|▏         | 8/433 [02:34<2:12:29, 18.70s/it]



  2%|▏         | 9/433 [02:52<2:11:29, 18.61s/it]



  2%|▏         | 10/433 [03:11<2:11:25, 18.64s/it]



  3%|▎         | 11/433 [03:32<2:14:58, 19.19s/it]



  3%|▎         | 12/433 [03:55<2:23:00, 20.38s/it]



  3%|▎         | 13/433 [04:15<2:22:30, 20.36s/it]



  3%|▎         | 14/433 [04:36<2:22:50, 20.45s/it]



  3%|▎         | 15/433 [04:56<2:22:54, 20.51s/it]



  4%|▎         | 16/433 [05:17<2:23:05, 20.59s/it]



  4%|▍         | 17/433 [05:37<2:22:14, 20.52s/it]



  4%|▍         | 18/433 [05:59<2:23:25, 20.74s/it]



  4%|▍         | 19/433 [06:19<2:23:12, 20.75s/it]



  5%|▍         | 20/433 [06:40<2:22:20, 20.68s/it]



  5%|▍         | 21/433 [07:01<2:23:14, 20.86s/it]



  5%|▌         | 22/433 [07:27<2:32:04, 22.20s/it]



  5%|▌         | 23/433 [07:47<2:28:27, 21.73s/it]



  6%|▌         | 24/433 [08:07<2:25:13, 21.30s/it]



  6%|▌         | 25/433 [08:27<2:22:11, 20.91s/it]



  6%|▌         | 26/433 [08:47<2:19:49, 20.61s/it]



  6%|▌         | 27/433 [09:08<2:18:33, 20.48s/it]



  6%|▋         | 28/433 [09:27<2:16:15, 20.19s/it]



  7%|▋         | 29/433 [09:47<2:16:05, 20.21s/it]



  7%|▋         | 30/433 [10:07<2:14:32, 20.03s/it]



  7%|▋         | 31/433 [10:28<2:16:45, 20.41s/it]



  7%|▋         | 32/433 [10:48<2:14:20, 20.10s/it]



  8%|▊         | 33/433 [11:09<2:15:59, 20.40s/it]



  8%|▊         | 34/433 [11:30<2:17:45, 20.71s/it]



  8%|▊         | 35/433 [11:52<2:18:38, 20.90s/it]



  8%|▊         | 36/433 [12:12<2:17:29, 20.78s/it]



  9%|▊         | 37/433 [12:34<2:20:25, 21.28s/it]



  9%|▉         | 38/433 [12:57<2:22:16, 21.61s/it]



  9%|▉         | 39/433 [13:18<2:21:06, 21.49s/it]



  9%|▉         | 40/433 [13:39<2:20:37, 21.47s/it]



  9%|▉         | 41/433 [14:00<2:17:30, 21.05s/it]



 10%|▉         | 42/433 [14:19<2:14:20, 20.62s/it]



 10%|▉         | 43/433 [14:39<2:13:17, 20.51s/it]



 10%|█         | 44/433 [14:59<2:10:59, 20.21s/it]



 10%|█         | 45/433 [15:18<2:08:49, 19.92s/it]



 11%|█         | 46/433 [15:38<2:09:16, 20.04s/it]



 11%|█         | 47/433 [15:56<2:04:32, 19.36s/it]



 11%|█         | 48/433 [16:14<2:01:50, 18.99s/it]



 11%|█▏        | 49/433 [16:32<1:59:14, 18.63s/it]



 12%|█▏        | 50/433 [16:50<1:57:02, 18.34s/it]



 12%|█▏        | 51/433 [17:09<1:57:55, 18.52s/it]



 12%|█▏        | 52/433 [17:27<1:57:40, 18.53s/it]



 12%|█▏        | 53/433 [17:45<1:56:26, 18.39s/it]



 12%|█▏        | 54/433 [18:05<1:57:51, 18.66s/it]



 13%|█▎        | 55/433 [18:24<1:58:53, 18.87s/it]



 13%|█▎        | 56/433 [18:44<2:00:55, 19.25s/it]



 13%|█▎        | 57/433 [19:05<2:02:54, 19.61s/it]



 13%|█▎        | 58/433 [19:25<2:03:53, 19.82s/it]



 14%|█▎        | 59/433 [19:46<2:05:47, 20.18s/it]



 14%|█▍        | 60/433 [20:06<2:05:07, 20.13s/it]



 14%|█▍        | 61/433 [20:26<2:05:08, 20.18s/it]



 14%|█▍        | 62/433 [20:47<2:05:01, 20.22s/it]



 15%|█▍        | 63/433 [21:06<2:03:11, 19.98s/it]



 15%|█▍        | 64/433 [21:26<2:02:33, 19.93s/it]



 15%|█▌        | 65/433 [21:45<2:01:48, 19.86s/it]



 15%|█▌        | 66/433 [22:06<2:02:40, 20.05s/it]



 15%|█▌        | 67/433 [22:25<1:59:49, 19.64s/it]



 16%|█▌        | 68/433 [22:45<2:01:06, 19.91s/it]



 16%|█▌        | 69/433 [23:03<1:56:57, 19.28s/it]



 16%|█▌        | 70/433 [23:23<1:57:56, 19.50s/it]



 16%|█▋        | 71/433 [23:43<1:58:50, 19.70s/it]



 17%|█▋        | 72/433 [24:03<1:59:17, 19.83s/it]



 17%|█▋        | 73/433 [24:25<2:02:17, 20.38s/it]



 17%|█▋        | 74/433 [24:45<2:01:31, 20.31s/it]



 17%|█▋        | 75/433 [25:05<1:59:39, 20.05s/it]



 18%|█▊        | 76/433 [25:25<1:59:19, 20.05s/it]



 18%|█▊        | 77/433 [25:45<2:00:11, 20.26s/it]



 18%|█▊        | 78/433 [26:05<1:59:04, 20.12s/it]



 18%|█▊        | 79/433 [26:25<1:57:52, 19.98s/it]



 18%|█▊        | 80/433 [26:46<2:00:20, 20.46s/it]



 19%|█▊        | 81/433 [27:06<1:59:17, 20.33s/it]



 19%|█▉        | 82/433 [27:27<1:58:35, 20.27s/it]



 19%|█▉        | 83/433 [27:47<1:57:56, 20.22s/it]



 19%|█▉        | 84/433 [28:08<1:58:48, 20.43s/it]



 20%|█▉        | 85/433 [28:30<2:01:48, 21.00s/it]



 20%|█▉        | 86/433 [28:53<2:04:38, 21.55s/it]



 20%|██        | 87/433 [29:15<2:04:37, 21.61s/it]



 20%|██        | 88/433 [29:37<2:05:14, 21.78s/it]



 21%|██        | 89/433 [30:00<2:06:49, 22.12s/it]



 21%|██        | 90/433 [30:22<2:07:45, 22.35s/it]



 21%|██        | 91/433 [30:46<2:09:09, 22.66s/it]



 21%|██        | 92/433 [31:08<2:07:58, 22.52s/it]



 21%|██▏       | 93/433 [31:30<2:06:40, 22.35s/it]



 22%|██▏       | 94/433 [31:52<2:05:25, 22.20s/it]



 22%|██▏       | 95/433 [32:14<2:05:02, 22.20s/it]



 22%|██▏       | 96/433 [32:36<2:04:45, 22.21s/it]



 22%|██▏       | 97/433 [32:58<2:03:14, 22.01s/it]



 23%|██▎       | 98/433 [33:19<2:00:47, 21.64s/it]



 23%|██▎       | 99/433 [33:41<2:02:29, 22.00s/it]



 23%|██▎       | 100/433 [34:03<2:01:46, 21.94s/it]



 23%|██▎       | 101/433 [34:25<2:01:39, 21.99s/it]



 24%|██▎       | 102/433 [34:48<2:02:19, 22.17s/it]



 24%|██▍       | 103/433 [35:09<1:59:18, 21.69s/it]



 24%|██▍       | 104/433 [35:30<1:58:52, 21.68s/it]



 24%|██▍       | 105/433 [35:51<1:57:14, 21.45s/it]



 24%|██▍       | 106/433 [36:13<1:56:54, 21.45s/it]



 25%|██▍       | 107/433 [36:33<1:54:56, 21.16s/it]



 25%|██▍       | 108/433 [36:54<1:54:22, 21.12s/it]



 25%|██▌       | 109/433 [37:16<1:55:02, 21.30s/it]



 25%|██▌       | 110/433 [37:37<1:54:29, 21.27s/it]



 26%|██▌       | 111/433 [37:58<1:54:00, 21.24s/it]



 26%|██▌       | 112/433 [38:20<1:54:10, 21.34s/it]



 26%|██▌       | 113/433 [38:42<1:55:32, 21.66s/it]



 26%|██▋       | 114/433 [39:04<1:55:32, 21.73s/it]



 27%|██▋       | 115/433 [39:26<1:55:35, 21.81s/it]



 27%|██▋       | 116/433 [39:48<1:55:53, 21.93s/it]



 27%|██▋       | 117/433 [40:13<1:59:36, 22.71s/it]



 27%|██▋       | 118/433 [40:37<2:00:52, 23.02s/it]



 27%|██▋       | 119/433 [41:01<2:02:45, 23.46s/it]



 28%|██▊       | 120/433 [41:23<2:00:37, 23.12s/it]



 28%|██▊       | 121/433 [41:46<1:59:34, 23.00s/it]



 28%|██▊       | 122/433 [42:08<1:58:12, 22.81s/it]



 28%|██▊       | 123/433 [42:31<1:57:17, 22.70s/it]



 29%|██▊       | 124/433 [42:54<1:57:03, 22.73s/it]



 29%|██▉       | 125/433 [43:17<1:58:18, 23.05s/it]



 29%|██▉       | 126/433 [43:40<1:57:34, 22.98s/it]



 29%|██▉       | 127/433 [44:04<1:57:59, 23.13s/it]



 30%|██▉       | 128/433 [44:26<1:56:30, 22.92s/it]



 30%|██▉       | 129/433 [44:48<1:54:48, 22.66s/it]



 30%|███       | 130/433 [45:09<1:51:35, 22.10s/it]



 30%|███       | 131/433 [45:29<1:48:08, 21.49s/it]



 30%|███       | 132/433 [45:50<1:47:08, 21.36s/it]



 31%|███       | 133/433 [46:11<1:45:38, 21.13s/it]



 31%|███       | 134/433 [46:31<1:43:51, 20.84s/it]



 31%|███       | 135/433 [46:51<1:42:42, 20.68s/it]



 31%|███▏      | 136/433 [47:11<1:41:00, 20.40s/it]



In [None]:
%ls data/saved_features

In [None]:
%rm data/saved_features/*

In [None]:
!du data/saved_features/* -h

81K	data/saved_features/face_embeddings_val_0.pt
81K	data/saved_features/face_embeddings_val_1.pt
81K	data/saved_features/face_embeddings_val_2.pt
81K	data/saved_features/face_embeddings_val_3.pt
81K	data/saved_features/face_embeddings_val_4.pt
81K	data/saved_features/face_embeddings_val_5.pt
81K	data/saved_features/face_embeddings_val_6.pt
481K	data/saved_features/face_embeddings_val.pt
1.0K	data/saved_features/real_indexes_val_0.pt
1.0K	data/saved_features/real_indexes_val_1.pt
1.0K	data/saved_features/real_indexes_val_2.pt
1.0K	data/saved_features/real_indexes_val_3.pt
1.0K	data/saved_features/real_indexes_val_4.pt
1.0K	data/saved_features/real_indexes_val_5.pt
1.0K	data/saved_features/real_indexes_val_6.pt
1.5K	data/saved_features/real_indexes_val.pt
49K	data/saved_features/text_embeddings_val_0.pt
49K	data/saved_features/text_embeddings_val_1.pt
49K	data/saved_features/text_embeddings_val_2.pt
49K	data/saved_features/text_embeddings_val_3.pt
49K	data/saved_features/text_embeddings

In [None]:
save_path = project_path / 'data' / 'saved_features'
dataset_type = "val"
face_embeddings = []
for i in range(3):
    face_embeddings.append(torch.load(save_path / f'face_embeddings_{dataset_type}_{i}.pt'))
face_embeddings = torch.cat(face_embeddings, dim=0)
print(face_embeddings.shape)
torch.save(face_embeddings, save_path / f'face_embeddings_{dataset_type}_{batch_index}.pt')
del face_embeddings

NameError: ignored

#Data Loader

In [None]:
class MSCTDDataLoader:
    def __init__(self, dl, device):
        self.dl = dl
        self.device = device

    def __iter__(self):
        for b in self.dl:
            yield to_device(b, self.device)

    def __len__(self):
        return len(self.dl)

def to_device(data, device):
    if isinstance(data, (list, tuple)):
        return [to_device(x, device) for x in data]
    if isinstance(data, dict):
        return {k: to_device(v, device) for k, v in data.items()}
    if isinstance(data, str):
        return data
    return data.to(device)

ds = MSCTDDataSet(base_path=project_path + "data/", dataset_type = "val", load=True)
dl = DataLoader(ds, batch_size=10)
dl = MSCTDDataLoader(dl, device)
for x in dl:
  print(x)
  print(x['face_embedding'].shape)
  print(x['text_embedding'].shape)
  print(x['real_index'])
  break

In [None]:
import torch
from torch import nn


class SimpleDenseNetwork(nn.Module):
    def __init__(self, n_classes, embedding_dimension):
        super(SimpleDenseNetwork, self).__init__()

        self.n_classes = n_classes
        self.embedding_dimension = embedding_dimension

        self.fc = nn.Sequential(
            nn.Linear(
                in_features=self.embedding_dimension,
                out_features=512,
            ),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=3),
            nn.ReLU(inplace=True),
            nn.Softmax(),
        )

    def forward(self, input_batch):
        x = input_batch
        x = self.fc(x)
        output_batch = x

        return output_batch

#Train

In [None]:
BATCH_SIZE = 32
num_workers = 1
EPOCHS = 1
embedding_dimension = 2048
learning_rate = 0.1
momentum = 0.9
data_size = 1000

In [None]:
import torch.optim as optim
from datetime import datetime
from torch.utils.data import DataLoader
from transformers import AutoTokenizer


def train_epoch(epoch_index, model, dataloader, loss_fn, optimizer):
    running_loss = 0.0
    last_loss = 0.0

    for data_pair_index, batch in enumerate(dataloader):
        print("--------------", data_pair_index, "-------------")
        text_embedding = batch["text_embedding"]
        face_embedding = batch["face"]
        # pose_embeding = batch["pose_embeding"]
        labels = batch["sentiment"]
        optimizer.zero_grad()

        outputs = model(torch.cat((text_embedding, face_embedding), 1))

        loss = loss_fn(outputs, labels)
        loss.backward()

        optimizer.step()

        # Gather data and report
        running_loss += loss.item()
        if data_pair_index % 1000 == 999:
            last_loss = running_loss / 1000  # loss per batch
            print("  batch {} loss: {}".format(data_pair_index + 1, last_loss))
            tb_x = epoch_index * len(dataloader) + data_pair_index + 1
            print("Loss/train", last_loss, tb_x)
            running_loss = 0.0

    return last_loss


def train_model(model, epochs, train_dataloader, val_dataloader):
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    for epoch in trange(epochs):
        model.train()
        train_epoch(epoch, model, train_dataloader, loss_fn, optimizer)
        model.eval()
        validate(model, val_dataloader, loss_fn)

    return model

In [None]:
# model = SimpleDenseNetwork(n_classes=3, embedding_dimension=embedding_dimension).to(device=device)

In [None]:
val_dataset = MSCTDDataSet(data_size, project_path + "data/", "val")
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE)
val_dataloader = MSCTDDataLoader(val_dataloader, device)

test_dataset = MSCTDDataSet(5000, project_path + "data/", "test")
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE)
test_dataloader = MSCTDDataLoader(test_dataloader, device)

/root/.hsemotions/enet_b0_8_best_afew.pt Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
/root/.hsemotions/enet_b0_8_best_afew.pt Compose(
    Resize(size=(224, 224), interpolation=bilinear, max_size=None, antialias=None)
    ToTensor()
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)


In [None]:
model = train_model(model, EPOCHS, val_dataloader, test_dataloader)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations


-------------- 0 -------------


  input = module(input)


-------------- 1 -------------
-------------- 2 -------------
-------------- 3 -------------
-------------- 4 -------------


#Evaluating

In [None]:
import evaluate

accuracy = evaluate.load("accuracy")
precision = evaluate.load("precision")


def validate(model, dataloader, loss_fn):
    running_loss = 0.0
    last_loss = 0.0

    for data_pair_index, batch in enumerate(dataloader):
        print("--------------", data_pair_index, "-------------")
        text_embedding = batch["text_embedding"]
        face_embedding = batch["face"]
        # pose_embeding = batch["pose_embeding"]
        labels = batch["sentiment"]

        logits = model(torch.cat((text_embedding, face_embedding), 1))
        # print(outputs)
        accuracy.add_batch(predictions=logits.argmax(dim=1), references=labels)
        precision.add_batch(predictions=logits.argmax(dim=1), references=labels)
        loss = loss_fn(logits, labels)
        running_loss += loss.item()
        # print(running_loss)
        # print('true answer',labels)
        # print('prediction',logits.argmax(dim=1))
        # if data_pair_index==2:
        #   break
    print(accuracy.compute())
    print(precision.compute(average=None))

In [None]:
validate(model, test_dataloader, nn.CrossEntropyLoss())

-------------- 0 -------------
-------------- 1 -------------
-------------- 2 -------------
-------------- 3 -------------
-------------- 4 -------------
-------------- 5 -------------
-------------- 6 -------------
-------------- 7 -------------
-------------- 8 -------------
-------------- 9 -------------
-------------- 10 -------------
-------------- 11 -------------
-------------- 12 -------------
-------------- 13 -------------
-------------- 14 -------------
-------------- 15 -------------
-------------- 16 -------------
-------------- 17 -------------
-------------- 18 -------------
-------------- 19 -------------
-------------- 20 -------------
-------------- 21 -------------
-------------- 22 -------------
-------------- 23 -------------
-------------- 24 -------------
-------------- 25 -------------
-------------- 26 -------------
-------------- 27 -------------
-------------- 28 -------------
-------------- 29 -------------
-------------- 30 -------------
-------------- 31 

  _warn_prf(average, modifier, msg_start, len(result))
