In [None]:
import torch
from torch import nn
from torchvision import models
from transformers import XLMRobertaModel, RobertaModel, RobertaTokenizer

# Model Architecture

In [None]:
class EncoderModule(nn.Module):
    """Stack of Transformer encoder layers, each wrapped in an outer residual sum.

    Args:
        n_layers: number of ``nn.TransformerEncoderLayer`` blocks in the stack.
        d_model: feature width of the input, i.e. the size of the last
            dimension of ``x`` (this is NOT the sequence length).
        nhead: number of attention heads; ``d_model`` must be divisible by it.
    """

    def __init__(self, n_layers=6, d_model=64, nhead=8):
        # nn.Module.__init__ has to run before any sub-module is assigned;
        # ``super(EncoderModule).__init__()`` (without ``self``) does not do that.
        super().__init__()
        self.n_layers = n_layers
        self.encoders = nn.ModuleList(
            [
                nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
                for _ in range(n_layers)
            ]
        )

    def forward(self, x):
        """Run ``x`` of shape (batch, seq, d_model) through every layer.

        Returns a tensor with the same shape as ``x``.
        """
        for encoder in self.encoders:
            # Outer skip connection on top of the layer's own internal residuals.
            x = x + encoder(x)
        return x

class FusionAttentionModule(nn.Module):
    """Iteratively fuses audio and video information into the text stream.

    In every layer the text state is used as query and value, while the audio
    (resp. video) embedding is used as key of a multi-head attention. The two
    attention outputs are concatenated with the text state along the feature
    axis and projected back to ``text_dim``.

    Args:
        n_layers: number of fusion layers.
        text_dim: feature width of the text embedding (attention embed dim).
        audio_dim: feature width of the audio embedding (key dim).
        video_dim: feature width of the video embedding (key dim).
        n_heads: number of attention heads; must divide ``text_dim``.
    """

    def __init__(self, n_layers=5, text_dim=768, audio_dim=128, video_dim=600, n_heads=8):
        super().__init__()  # required before registering sub-modules
        self.n_layers = n_layers
        self.audio_attn = nn.ModuleList()
        self.video_attn = nn.ModuleList()
        self.fcs = nn.ModuleList()
        for _ in range(n_layers):
            self.audio_attn.append(
                nn.MultiheadAttention(text_dim, n_heads, kdim=audio_dim, batch_first=True)
            )
            self.video_attn.append(
                nn.MultiheadAttention(text_dim, n_heads, kdim=video_dim, batch_first=True)
            )
            # Both attention outputs and the text state are ``text_dim`` wide,
            # so the concatenation is 3 * text_dim wide.
            self.fcs.append(nn.Linear(3 * text_dim, text_dim))

    def forward(self, audio_emb, text_emb, video_emb):
        """Fuse the three modalities.

        Args:
            audio_emb: (batch, seq, audio_dim)
            text_emb: (batch, seq, text_dim)
            video_emb: (batch, seq, video_dim)

        All three inputs must share the same sequence length, because the
        key (audio/video) and the value (text) of one attention call have to
        be equally long.

        Returns:
            Tensor of shape (batch, seq, text_dim).
        """
        text_state = text_emb
        for audio_attn, video_attn, fc in zip(self.audio_attn, self.video_attn, self.fcs):
            # MultiheadAttention returns (output, weights); only the output is needed.
            audio_out, _ = audio_attn(text_state, audio_emb, text_state, need_weights=False)
            video_out, _ = video_attn(text_state, video_emb, text_state, need_weights=False)
            fused = torch.cat([audio_out, text_state, video_out], dim=-1)
            text_state = fc(fused)
        return text_state

# Text Feature Extractor

In [None]:
# Text feature extractor: smoke test of the pretrained checkpoint
# 'tae898/emoberta-base' loaded through the transformers Auto* classes.

from transformers import AutoTokenizer, AutoModel
tokenizer = AutoTokenizer.from_pretrained('tae898/emoberta-base')
model = AutoModel.from_pretrained('tae898/emoberta-base')

# Encode a single sample sentence as PyTorch tensors.
enc = tokenizer(["Привет как дела"], return_tensors='pt')
# Average last_hidden_state over dim=1 and display the resulting size.
# NOTE(review): the warning recorded in this cell's output says the pooler
# weights are newly initialized, so pooler_output is presumably not usable
# without further training - confirm before switching to it.
model(**enc).last_hidden_state.mean(dim=1).size() # open question: last_hidden_state or pooler_output?

Some weights of RobertaModel were not initialized from the model checkpoint at tae898/emoberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


torch.Size([1, 768])

In [None]:
class M2FNet(nn.Module):
    """Multimodal classifier: per-modality encoders, fusion attention and a
    two-layer fully connected head producing 7 outputs per position."""

    def __init__(self):
        super().__init__()

        # TODO: plug in the modality-specific feature extractors
        # (AudioFeatureExtractor / TextFeatureExtractor / VideoFeatureExtractor).

        # NOTE(review): the encoders and the fusion module are built with their
        # default arguments; confirm that the feature widths they are created
        # with match the 128 / 768 / 600-wide inputs described in forward().
        self.audio_encoder = EncoderModule()
        self.text_encoder = EncoderModule()
        self.video_encoder = EncoderModule()

        self.fusion_attn = FusionAttentionModule()

        # Width of [audio | fused text | video] after concatenation.
        head_width = 768 + 600 + 128
        self.fc1 = nn.Linear(in_features=head_width, out_features=head_width)  # TODO: finalise dimensions
        self.fc2 = nn.Linear(in_features=head_width, out_features=7)  # TODO: finalise dimensions

    def forward(self, audio_emb, text_emb, video_emb):
        """Classify from the three modality embeddings.

        Expected layouts (to be revised once the data format is known):
            audio_emb - (B, UttLen, Hid_a), exact format still unknown
            text_emb  - (B, UttLen, Hid_t), i.e. (B, UttLen, 768)
            video_emb - (B, UttLen, Hid_v), i.e. (B, UttLen, 600)

        Returns the output of the second fully connected layer.
        """
        encoded_audio = self.audio_encoder(audio_emb)  # (B, UttLen, Hid_a)
        encoded_text = self.text_encoder(text_emb)  # (B, UttLen, 768)
        encoded_video = self.video_encoder(video_emb)  # (B, UttLen, 600)

        # Text stream enriched with audio/video context.
        fused_text = self.fusion_attn(encoded_audio, encoded_text, encoded_video)

        head_input = torch.cat([encoded_audio, fused_text, encoded_video], dim=-1)
        return self.fc2(self.fc1(head_input))

# Visual Feature Extractor

1. Получить эмбеддинг сцены (кадра)
2. Взять максимум по 0 размерности (чтобы остался 1 вектор размерности hidden_dim)
3. Детектировать лица на кадре
4. Получить из кропнутых лиц эмбеддинги
5. Посчитать взвешенную сумму на основе площадей боксов лиц
6. Взять максимум по 0 размерности (чтобы остался 1 вектор размерности hidden_dim)
7. Конкатенировать эмбеддинг сцены и эмбеддинг лиц

*Фейс эмбеддеры лучше поискать готовые (по совету Савченко)*

In [None]:
def adaptive_margin_triplet_loss(anchor, positive, negative, margin=1.0):
    """Triplet loss with an adaptive (distance-dependent) margin.

    Args:
        anchor, positive, negative: embeddings of identical shape (batch, dim).
        margin: currently unused; kept so existing callers keep working.

    Returns:
        Tensor of shape (batch,) with one loss value per triplet (no reduction,
        no clamping).
    """
    # Euclidean distance for every pair. ``nn.PairwiseDistance`` is a module
    # whose constructor takes hyper-parameters, not tensors, so the functional
    # form is used here.
    dist_anchor_positive = nn.functional.pairwise_distance(anchor, positive, p=2)
    dist_anchor_negative = nn.functional.pairwise_distance(anchor, negative, p=2)
    dist_positive_negative = nn.functional.pairwise_distance(positive, negative, p=2)

    # The margin grows when the positive is close to the anchor and when the
    # negative is far from it.
    adaptive_margin = (
        2
        + (2 / torch.exp(4 * dist_anchor_positive))
        + (2 / torch.exp(-4 * dist_anchor_negative + 4))
    )

    # Pull the positive in, push the negative away from both anchor and positive.
    loss = (
        dist_anchor_positive
        - (dist_anchor_negative + dist_positive_negative) / 2
        + adaptive_margin
    )
    return loss

def variance_loss(embeddings, epsilon=1e-6):
    """Hinge penalty on per-feature standard deviations below 1.

    Args:
        embeddings: tensor of shape (batch, dim); the variance is taken over
            the batch axis (dim 0).
        epsilon: added to the variance before the square root for numerical
            stability.

    Returns:
        Scalar tensor: mean over features of ``max(0, 1 - std)``.

    Note: the caller is expected to sum this over anchor, positive and negative.
    """
    std = torch.sqrt(embeddings.var(dim=0) + epsilon)
    # Without the hinge the term turns negative for std > 1 and would reward
    # inflating the variance without bound.
    return torch.mean(torch.relu(1 - std))

# def covariance_loss(embeddings):
#     n = embeddings.size(0)
#     cov_matrix = torch.mm((embeddings - embeddings.mean(dim=0)).T, embeddings - embeddings.mean(dim=0)) / (n - 1)
#     cov_loss = torch.sum(cov_matrix ** 2) - torch.sum(torch.diag(cov_matrix) ** 2)
#     return cov_loss / embeddings.size(1)

def covariance_loss(embeddings):
    """Penalise correlation between different embedding features.

    Args:
        embeddings: tensor of shape (batch, dim).

    Returns:
        Scalar tensor: sum of the squared off-diagonal entries of the
        (dim x dim) feature covariance matrix, divided by ``dim``.
    """
    n_samples, n_features = embeddings.shape
    centered = embeddings - embeddings.mean(dim=0, keepdim=True)
    # Features are the variables, batch rows are the observations, hence
    # centered.T @ centered (``Tensor.cov()`` on the raw input would instead
    # treat every batch row as a variable).
    covariance = centered.T @ centered / max(n_samples - 1, 1)
    # Squaring keeps positive and negative covariances from cancelling out.
    off_diagonal_sq = covariance.pow(2).sum() - covariance.diagonal().pow(2).sum()
    return off_diagonal_sq / n_features

def combined_feature_extractor_loss(anchor, positive, negative, lambda1=1.0, lambda2=1.0, lambda3=1.0):
    """Weighted sum of the triplet, variance and covariance terms.

    Args:
        anchor, positive, negative: embeddings of shape (batch, dim).
        lambda1: weight of the adaptive-margin triplet term.
        lambda2: weight of the variance term (summed over the three inputs).
        lambda3: weight of the covariance term (summed over the three inputs).

    Returns:
        Scalar tensor, so it can be passed to ``.backward()`` directly.
    """
    triplet_term = adaptive_margin_triplet_loss(anchor, positive, negative)

    groups = (anchor, positive, negative)
    variance_term = sum(variance_loss(embeddings) for embeddings in groups)
    covariance_term = sum(covariance_loss(embeddings) for embeddings in groups)

    total = lambda1 * triplet_term + lambda2 * variance_term + lambda3 * covariance_term
    # Some terms are per-sample vectors; average so the result is one scalar
    # whatever the shape of the individual terms.
    return total.mean()

# Miscellaneous

In [None]:
import numpy as np
import mxnet as mx
from mxnet import recordio
import matplotlib.pyplot as plt

# Preview of the CASIA-WebFace record file: iterate over it with an mxnet
# ImageIter in batches of 4 images of shape (3, 112, 112).
data_iter = mx.image.ImageIter(
    batch_size=4,
    data_shape=(3, 112, 112),
    path_imgrec="./Data/casia-webface/train.rec",
    path_imgidx="./Data/casia-webface/train.idx",
)
data_iter.reset()
# Show the first 4 batches, one figure with a 1x4 grid per batch.
for j in range(4):
    batch = data_iter.next()
    data = batch.data[0]
    # print(batch)
    label = batch.label[0].asnumpy()
    for i in range(4):
        ax = plt.subplot(1, 4, i + 1)
        # Move the first axis last (3, 112, 112) -> (112, 112, 3) for imshow.
        plt.imshow(data[i].asnumpy().astype(np.uint8).transpose((1, 2, 0)))
        ax.set_title("class: " + str(label[i]))
        plt.axis("off")
    plt.show()

# # ======= Code to show single image =======#
# path_imgrec = "./Data/casia-webface/train.rec"
# path_imgidx = "./Data/casia-webface/train.idx"
# imgrec = recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
# # %% 1 ~ 409623
# # for i in range(409623):
# for i in range(10):
#     header, s = recordio.unpack(imgrec.read_idx(i + 1))
#     img = mx.image.imdecode(s).asnumpy()
#     plt.imshow(img)
#     plt.title("id=" + str(i) + "label=" + str(header.label))
#     plt.pause(0.1)


In [None]:
import os  # not imported by any earlier cell, needed for os.path.join below

DATA_ROOT = "./Data/casia-webface/"
INPUT_SIZE = [112, 112]

# The "property" file holds three comma-separated integers:
# number of classes, image height, image width.
with open(os.path.join(DATA_ROOT, "property"), "r") as f:
    NUM_CLASS, h, w = [int(i) for i in f.read().split(",")]
# The stored image size has to match the size the model is fed with.
assert h == INPUT_SIZE[0] and w == INPUT_SIZE[1]
print("Number of Training Classes: {}".format(NUM_CLASS))


# Launch test

In [4]:
# !pip install facenet_pytorch

In [3]:
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from facenet_pytorch import MTCNN
from PIL import Image, ImageDraw
import cv2
import matplotlib.pyplot as plt


class VisualFeatureExtractor(nn.Module):
    """ResNet-18 backbone followed by a linear projection.

    Args:
        feature_dim: width of the embedding produced by the projection.
    """

    def __init__(self, feature_dim=300):
        super().__init__()

        backbone = models.resnet18(weights='DEFAULT')
        # Remember the width of the pooled features before dropping the
        # classification head.
        backbone_width = backbone.fc.in_features
        backbone.fc = nn.Identity()
        self.resnet = backbone

        # Linear projection from the backbone width down to feature_dim.
        self.projector = nn.Linear(backbone_width, feature_dim)

    def forward(self, x):
        """Return the projected backbone features for the batch ``x``."""
        return self.projector(self.resnet(x))


class FaceExtractor:
    """Detects faces in a frame and returns one area-weighted face embedding.

    Args:
        feature_extractor: callable mapping an image batch to embeddings; it
            must expose ``projector.out_features`` (the embedding width).
        device: device used for the detector and for the tensors fed to
            ``feature_extractor``.
    """

    def __init__(self, feature_extractor, device='cpu'):
        # Run the detector on the same device as the rest of the pipeline.
        self.mtcnn = MTCNN(device=device)
        self.feature_extractor = feature_extractor
        self.device = device
        # NOTE: same preprocessing as VideoFeatureExtractor.img_transform;
        # keep the two in sync.
        self.img_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def extract_faces_embedding(self, frame):
        """Return the area-weighted average embedding of all faces in ``frame``.

        Args:
            frame: PIL image, or an image tensor (optionally with a leading
                batch axis of size 1) that is converted to a PIL image first.

        Returns:
            1-D tensor of width ``feature_extractor.projector.out_features``;
            all zeros when no usable face is found.
        """
        if isinstance(frame, torch.Tensor):
            frame = transforms.ToPILImage()(frame.squeeze(0))

        emb_dim = self.feature_extractor.projector.out_features
        face_embedding = torch.zeros(emb_dim, device=self.device)

        # Batched detection call with a single frame: take the first entry.
        boxes, _ = self.mtcnn.detect([frame])
        boxes = boxes[0] if boxes is not None else None
        if boxes is None:
            # No face was found.
            return face_embedding

        # Keep only boxes with a positive width and height; a degenerate box
        # would otherwise produce an empty crop and a zero total area.
        sized_boxes = []
        for box in boxes:
            width = box[2] - box[0]
            height = box[3] - box[1]
            if width > 0 and height > 0:
                sized_boxes.append((box, float(width * height)))

        total_area = sum(area for _, area in sized_boxes)
        if total_area <= 0:
            return face_embedding

        for box, area in sized_boxes:
            face = frame.crop((box[0], box[1], box[2], box[3]))
            face_tensor = self.img_transform(face).unsqueeze(0).to(self.device)
            face_emb = self.feature_extractor(face_tensor).squeeze(0)
            # Larger faces contribute proportionally more.
            face_embedding = face_embedding + face_emb * (area / total_area)

        return face_embedding


class VideoFeatureExtractor(nn.Module):
    """Builds one embedding per clip from scene and face features.

    For every frame a scene embedding (whole frame) and a face embedding are
    computed; both are max-pooled over the frames and concatenated, giving a
    vector of width ``2 * feature_dim``.

    Args:
        feature_dim: width of the scene embedding and of the face embedding.
        device: device handed to the face extractor.
    """

    def __init__(self, feature_dim=300, device='cpu'):
        super().__init__()
        self.visual_feature_extractor = VisualFeatureExtractor(feature_dim)
        self.face_extractor = FaceExtractor(self.visual_feature_extractor, device=device)
        self.img_transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def forward(self, frames):
        """Embed a sequence of frames.

        Args:
            frames: iterable of images accepted by ``img_transform``.

        Returns:
            1-D tensor: max-pooled scene embedding followed by the max-pooled
            face embedding.

        Raises:
            ValueError: if ``frames`` is empty.
        """
        # Feed the inputs on whatever device the weights currently live on.
        first_param = next(self.parameters(), None)

        scene_embeddings = []
        face_embeddings = []
        for frame in frames:
            scene_input = self.img_transform(frame).unsqueeze(0)
            if first_param is not None:
                scene_input = scene_input.to(first_param.device)
            scene_embeddings.append(self.visual_feature_extractor(scene_input))

            # The face detector gets the original frame, not the resized
            # tensor, so faces are detected at full resolution and without
            # the aspect-ratio distortion of the 224x224 resize.
            face_embeddings.append(self.face_extractor.extract_faces_embedding(frame))

        if not scene_embeddings:
            raise ValueError("frames must contain at least one frame")

        # Max pooling over the frame axis.
        scene_embedding = torch.cat(scene_embeddings, dim=0).max(dim=0)[0]
        face_embedding = torch.stack(face_embeddings).max(dim=0)[0]

        return torch.cat((scene_embedding, face_embedding), dim=0)


def get_pil_frames(video_path, num_frames=15, device='cpu'):
    """Read the first ``num_frames`` frames of a video as RGB PIL images.

    Args:
        video_path: path of the video file.
        num_frames: number of leading frames to read.
        device: unused; kept so existing callers keep working.

    Returns:
        List of ``num_frames`` PIL images.

    Raises:
        ValueError: if the video cannot be opened or holds fewer than
            ``num_frames`` frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        if not cap.isOpened():
            raise ValueError(f"Could not open video: {video_path}")

        while len(frames) < num_frames:
            ret, frame = cap.read()
            if not ret:
                break
            # OpenCV decodes to BGR; PIL expects RGB.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        # Release the capture handle even when decoding fails.
        cap.release()

    if len(frames) < num_frames:
        raise ValueError(f"Video contains fewer than {num_frames} frames.")

    return frames


def display_frame(frame):
    """Show ``frame`` on the current axes with the axis decorations hidden."""
    axes = plt.gca()
    image = axes.imshow(frame)
    plt.sci(image)  # keep pyplot's "current image" in sync, as plt.imshow does
    axes.axis('off')
    plt.show()


feature_dim = 300
#device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
video_feature_extractor = VideoFeatureExtractor(feature_dim=feature_dim, device=device).to(device)
# Inference only: switch to eval mode so the pretrained batch-norm layers use
# their stored statistics instead of per-frame batch statistics (and do not
# update them while extracting features).
video_feature_extractor.eval()

video_path = "./dev_splits_complete/dia0_utt1.mp4"
frames = get_pil_frames(video_path)

# No gradients are needed for feature extraction.
with torch.no_grad():
    embedding = video_feature_extractor(frames)
print("Embedding shape:", embedding.shape)


Embedding shape: torch.Size([600])
