In [3]:
from trainer import ViT
from VITforEMGAndBone import ViTForEMGAndBone, CrossAttention
from GCNforVIdeo import GGCN, find_adjacency_matrix
from torch import nn
import torch
from einops import rearrange, repeat
import math
from loader.dataloader import MultiModalData
from torch.utils.data import Dataset, DataLoader
from util.log import Log
from sklearn.metrics import precision_score, recall_score, f1_score
from tqdm.auto import tqdm
import random
import os
import numpy as np
import wandb
import torch.nn.functional as F



# Authenticate with W&B from the environment instead of a key committed in the
# notebook. If WANDB_API_KEY is unset, wandb.login falls back to stored
# credentials (e.g. ~/.netrc) or prompts interactively.
# SECURITY: the key previously hard-coded in this cell is exposed in the
# notebook's history and should be revoked/rotated in the W&B settings.
_wandb_api_key = os.environ.get("WANDB_API_KEY")
wandb.login(
    key=_wandb_api_key,
    relogin=_wandb_api_key is not None
)

run = wandb.init(
    # Set the project where this run will be logged
    project="Missing_modality",
    # Track hyperparameters and run metadata
    config={
        "learning_rate": 0.01,
        "epochs": 100,
        'random_seed': 20,
        "common_dim": 64,
        "n_classes": 41,
        "batch_size": 128,
        "T": 1,               # temperature passed to super_gmc_loss
        "device":'cuda:1',    # hard-coded GPU; also read inside MultiModal.__init__
        "cl_rate": 1,         # weight of the contrastive (GMC) term in super_gmc_loss
        "param_vid": 1.5,     # NOTE(review): not referenced by the code shown here
        "param_emg": 0.8      # NOTE(review): not referenced by the code shown here
    })


class MultiModal(nn.Module):
    """Classifier combining a video (``bones``) branch, an EMG branch and a joint branch.

    ``forward`` returns, in the order video, EMG, multimodal:
      * three ``common_dim``-wide representations (fed to the contrastive loss), and
      * three logit tensors of width ``n_classes`` produced by a shared head.

    NOTE(review): submodules are constructed in a fixed order; under a global
    seed that order determines the random initialisation, so reordering them
    changes training results even though checkpoints load by attribute name.
    """
    def __init__(self, common_dim, n_classes):
        """Build all sub-networks.

        Args:
            common_dim: width of the shared projection / classification space.
            n_classes: number of outputs of the final shared ``fc`` head.

        NOTE(review): the GGCN, ViT and ViTForEMGAndBone sub-networks are built
        with a hard-coded 41 rather than ``n_classes``, and the GGCN reads its
        device from the global ``run.config["device"]``.
        """
        super(MultiModal, self).__init__()
        # Video branch: graph network over the adjacency from find_adjacency_matrix().
        self.gcn = GGCN(find_adjacency_matrix(), 41,
                        [3, 9], [9, 16, 32, 64], run.config["device"], 0.0)
        # EMG branch. NOTE(review): 44100*0.2 evaluates to the float 8820.0; presumably
        # the EMG window length in samples — confirm ViT is fine with a float here.
        self.vit = ViT(emg_size=(44100*0.2, 8), patch_height=60, num_classes=41, dim=128,
                       depth=5, mlp_dim=256, heads=8, pool='cls', dropout=0.45, emb_dropout=0.45).double()
        # Bidirectional cross-attention between the flattened bone sequence and the
        # chunked EMG sequence built in forward (widths 63 and 14112 assumed from
        # these arguments — TODO confirm against CrossAttention).
        self.crossAtt1 = CrossAttention(63, 14112, 512).double()
        self.crossAtt2 = CrossAttention(14112, 63, 512).double()
        # Joint branch over the concatenated cross-attention outputs
        # (1024 presumably = 512 + 512 — TODO confirm CrossAttention output width).
        self.vitForEMGandBone = ViTForEMGAndBone(
            1024,  41, 512, 3, 8, 512, pool='cls', dim_head=64, dropout=0, emb_dropout=0).double()

        # Final classifier, shared by all three branches.
        self.fc = nn.Linear(common_dim, n_classes).double()
        # NOTE(review): self.act is not used in forward.
        self.act = nn.ReLU() 
        
        # Per-branch projections to common_dim. Input widths as declared here:
        # 1536 (video features), 128 (EMG features) and 2176 for the concatenated
        # multimodal vector (presumably 1536 + 128 + 512 — TODO confirm).
        self.fc1 = nn.Linear(1536, common_dim).double()
        self.fc2 = nn.Linear(128, common_dim).double()
        self.fc3 = nn.Linear(2176, common_dim).double()
        # Shared projection applied after fc1 / fc2 / fc3.
        self.fc4 = nn.Linear(common_dim, common_dim).double()
        
        # Shared block used residually in forward: classify(x) + x.
        self.classify = nn.Sequential(
            nn.Linear(common_dim,common_dim),
            nn.ReLU(),
            nn.Dropout(p = 0.1)
        ).double()
        # p=0.0, so this is currently a no-op on the video features.
        self.drop = nn.Dropout(0.0)
         
    def forward(self, bones, emg):
        """Run all three branches and the shared head.

        Args:
            bones: video input, rearranged with the pattern "b t n c" below.
            emg: EMG input, rearranged with the pattern "b (a d) c" (a=5) below.

        Returns:
            ``([x1, x2, x5], [a1, a2, a5])``: projected representations and logits,
            each ordered video, EMG, multimodal.
        """
        x1 = self.gcn(bones)  # video features
        x1 = self.drop(x1)
        x2 = self.vit(emg.double())  # EMG features
        # Split axis 1 of emg into 5 equal chunks and flatten each chunk with the last axis.
        emg1 = rearrange(emg, "b (a d) c ->b a (d c) ", a=5).double()
        # Merge the last two axes of bones.
        bone1 = rearrange(bones, "b t n c -> b t (n c)").double()
        x3 = self.crossAtt1(bone1, emg1)
        x4 = self.crossAtt2(emg1, bone1)
        x5 = self.vitForEMGandBone(torch.concat(
            [x3, x4], dim=-1).double())  # joint features
        # Multimodal vector: video, EMG and joint features side by side.
        x5 = torch.concat([x1, x2, x5], dim=-1)
        # Project every branch into common_dim.
        x1 = self.fc1(x1.double())
        x2 = self.fc2(x2)
        x5 = self.fc3(x5)
        
        # Shared projection (same weights for all branches).
        x1 = self.fc4(x1)
        x2 = self.fc4(x2)
        x5 = self.fc4(x5)
        
        # Shared residual block, then the shared linear classifier.
        a1 = self.classify(x1) + x1
        a2 = self.classify(x2) + x2
        a5 = self.classify(x5) + x5

        a1 = self.fc(a1)
        a2 = self.fc(a2)
        a5 = self.fc(a5)

        return [x1, x2, x5], [a1, a2, a5]  # video,emg,multimodal

    

        
    

def super_gmc_loss(criterion,prediction, target, batch_representations, temperature, batch_size, cl_rate=2):
    """Supervised classification loss plus a GMC-style joint/unimodal contrastive loss.

    For every unimodal representation ``r_m`` (all entries of
    ``batch_representations`` except the last), the joint representation ``r_j``
    (the last entry) and ``r_m`` are stacked into ``2*B`` rows. Rows ``i`` and
    ``i+B`` form the positive pair; every other off-diagonal row is a negative:

        loss_i = logsumexp_{k != i}(s_ik / T) - s_pos / T,    s = dot product

    This equals ``-log(exp(s_pos/T) / sum_{k != i} exp(s_ik/T))`` but is evaluated
    in log-space. Large (un-normalised) dot products therefore no longer overflow
    ``exp`` to ``inf`` and turn the loss into ``nan``.

    Args:
        criterion: supervised loss, applied to ``prediction[-1]`` and ``target``.
        prediction: list of logits; only the last (multimodal) entry is supervised.
        target: class labels for the batch.
        batch_representations: list of ``[B, D]`` tensors; the last is the joint one.
        temperature: temperature ``T`` dividing all similarities.
        batch_size: ``B``; must equal the number of rows of each representation.
        cl_rate: weight applied to the summed contrastive loss.

    Returns:
        ``(loss, L_GCM, L_classify)``: the differentiable total loss, and the mean
        contrastive and supervised components as Python floats.

    Raises:
        ValueError: if ``batch_size`` does not match the joint representation rows.
    """
    joint_rep = batch_representations[-1]
    if joint_rep.shape[0] != batch_size:
        raise ValueError(
            f"batch_size={batch_size} does not match representation rows ({joint_rep.shape[0]})"
        )

    supervised_loss = criterion(prediction[-1], target)

    # Start from a zero tensor (not a Python 0) so the reductions below still work
    # when there are no unimodal representations.
    joint_mod_loss_sum = torch.zeros_like(supervised_loss)

    n_rows = 2 * batch_size
    # Mask that removes the trivial self-similarities on the diagonal.
    off_diagonal = ~torch.eye(n_rows, dtype=torch.bool, device=joint_rep.device)

    for mod_rep in batch_representations[:-1]:
        # [2B, D]: joint rows first, then this modality's rows.
        out_joint_mod = torch.cat([joint_rep, mod_rep], dim=0)
        # [2B, 2B] scaled similarity logits.
        logits = torch.mm(out_joint_mod, out_joint_mod.t()) / temperature
        # Drop the diagonal -> [2B, 2B-1] (positive + negatives per row).
        neg_logits = logits.masked_select(off_diagonal).view(n_rows, -1)
        # Positive logit for rows i and i+B: joint/unimodal dot product of sample i.
        pos_logits = torch.sum(joint_rep * mod_rep, dim=-1) / temperature
        pos_logits = torch.cat([pos_logits, pos_logits], dim=0)
        # -log(softmax) of the positive, computed stably in log-space. [2B]
        loss_joint_mod = torch.logsumexp(neg_logits, dim=-1) - pos_logits
        joint_mod_loss_sum = joint_mod_loss_sum + loss_joint_mod

    joint_mod_loss_sum = joint_mod_loss_sum * cl_rate

    L_GCM = torch.mean(joint_mod_loss_sum).item()
    L_classify = torch.mean(supervised_loss).item()

    loss = torch.mean(joint_mod_loss_sum + supervised_loss)

    return loss,L_GCM,L_classify

def train(train_loader, model, criterion, optimizer, device, T, loss_log, cl_rate=None):
    """Run one training epoch and append the epoch's mean losses to ``loss_log``.

    Args:
        train_loader: iterable of ``(videos, labels, emgs)`` batches.
        model: module returning ``(representations, logits)`` like ``MultiModal``.
        criterion: supervised loss forwarded to ``super_gmc_loss``.
        optimizer: optimizer updating ``model``'s parameters.
        device: device each batch is moved to.
        T: contrastive temperature forwarded to ``super_gmc_loss``. This argument
            is now honoured instead of being silently replaced by ``run.config['T']``.
        loss_log: three lists ``[gmc, classification, total]``; one epoch mean
            is appended to each.
        cl_rate: contrastive-loss weight; ``None`` (default) uses
            ``run.config['cl_rate']``.

    Returns:
        ``(model, epoch_loss, optimizer, loss_log)``.
    """
    if cl_rate is None:
        cl_rate = run.config['cl_rate']
    running_loss = 0
    loss_gcm = 0
    loss_classify = 0
    model.train()

    for videos, labels, emgs in tqdm(train_loader):

        videos = videos.to(device)
        labels = labels.to(device)
        emgs = emgs.to(device).double()

        # forward
        outputs, outputs2 = model(videos, emgs)

        loss, L_GMC, classification_loss = super_gmc_loss(
            criterion, outputs2, labels, outputs, T, len(labels), cl_rate)

        running_loss += loss.item()
        loss_gcm += L_GMC
        loss_classify += classification_loss

        # backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    n_batches = len(train_loader)
    epoch_loss = running_loss / n_batches
    loss_gcm = loss_gcm / n_batches
    loss_classify = loss_classify / n_batches
    loss_log[0].append(loss_gcm)
    loss_log[1].append(loss_classify)
    loss_log[2].append(epoch_loss)

    return model, epoch_loss, optimizer, loss_log


def validate(valid_loader, model, criterion, device, T, val_loss_log, cl_rate=None):
    """Compute the mean GMC and classification losses over ``valid_loader``.

    Runs without gradient tracking and does not update ``model``.

    Args:
        valid_loader: iterable of ``(videos, labels, emgs)`` batches.
        model: module returning ``(representations, logits)`` like ``MultiModal``.
        criterion: supervised loss forwarded to ``super_gmc_loss``.
        device: device each batch is moved to.
        T: contrastive temperature forwarded to ``super_gmc_loss``. This argument
            is now honoured instead of being replaced by ``run.config['T']``.
        val_loss_log: three lists ``[gmc, classification, total]``; one mean is
            appended to each.
        cl_rate: contrastive-loss weight; ``None`` (default) uses
            ``run.config['cl_rate']``.

    Returns:
        ``(model, epoch_loss, val_loss_log)``.
    """
    if cl_rate is None:
        cl_rate = run.config['cl_rate']
    model.eval()
    running_loss = 0
    loss_gcm = 0
    loss_classify = 0

    # Evaluation needs no gradients; without no_grad an autograd graph was built
    # for every batch, wasting memory and time.
    with torch.no_grad():
        for videos, labels, emgs in tqdm(valid_loader):

            videos = videos.to(device)
            labels = labels.to(device)
            emgs = emgs.to(device).double()

            # forward
            outputs, outputs2 = model(videos, emgs)

            loss, L_GMC, classification_loss = super_gmc_loss(
                criterion, outputs2, labels, outputs, T, len(labels), cl_rate)

            running_loss += loss.item()
            loss_gcm += L_GMC
            loss_classify += classification_loss

    n_batches = len(valid_loader)
    epoch_loss = running_loss / n_batches
    loss_gcm = loss_gcm / n_batches
    loss_classify = loss_classify / n_batches
    val_loss_log[0].append(loss_gcm)
    val_loss_log[1].append(loss_classify)
    val_loss_log[2].append(epoch_loss)
    return model, epoch_loss, val_loss_log


def get_accuracy(model, data_loader, device, modality='multimodal'):
    """Evaluate one classification head of ``model`` over ``data_loader``.

    Args:
        model: module returning ``(representations, logits)`` with logits ordered
            ``[video, emg, multimodal]``.
        data_loader: iterable of ``(videos, labels, emgs)`` batches.
        device: device each batch is moved to.
        modality: ``'multimodal'`` selects the last head and ``'emg'`` selects
            index 1. Any other value (e.g. ``'vid'``) selects the video head
            (index 0).

    Returns:
        ``(accuracy, f1, precision, recall, predicted_labels, truth_labels)``.
        F1, precision and recall are **macro**-averaged. The two label lists hold
        one 0-d tensor per sample, on ``device``.

    Raises:
        ValueError: if ``data_loader`` yields no samples.
    """
    head_index = {'multimodal': -1, 'emg': 1}.get(modality, 0)
    correct = 0
    total = 0
    predicted_labels = []
    truth_labels = []

    model.eval()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        for videos, labels, emgs in data_loader:
            videos = videos.to(device)
            labels = labels.to(device)
            emgs = emgs.to(device).double()

            # forward
            _, outputs = model(videos, emgs)
            # softmax is monotonic, so argmax over the logits picks the same class.
            predicted = torch.argmax(outputs[head_index], 1)

            total += labels.shape[0]
            correct += (predicted == labels).sum().item()
            predicted_labels.extend(predicted)
            truth_labels.extend(labels)

    if total == 0:
        raise ValueError("data_loader yielded no samples")

    # Convert once and reuse for all three metrics.
    y_true = torch.stack(truth_labels).cpu().numpy()
    y_pred = torch.stack(predicted_labels).cpu().numpy()
    f1_macro = f1_score(y_true, y_pred, average='macro')
    precision_macro = precision_score(y_true, y_pred, average='macro')
    recall_macro = recall_score(y_true, y_pred, average='macro')

    return correct / total, f1_macro, precision_macro, recall_macro, predicted_labels, truth_labels


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN for reproducibility.

    Args:
        seed: integer seed applied to every RNG.
    """
    random.seed(seed)
    # Only affects Python processes started after this point; the running
    # interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Explicitly seed every visible GPU (the configured device may not be cuda:0).
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # benchmark=True lets cuDNN choose kernels by timing them, which can differ
    # between runs and defeats deterministic=True, so it must be off.
    torch.backends.cudnn.benchmark = False


seed_everything(run.config["random_seed"])

trainset = MultiModalData("data/new_data/new_train_files.pkl")
testset = MultiModalData("data/new_data/new_test_files.pkl")
valset = MultiModalData("data/new_data/new_val_files.pkl")

train_loader = DataLoader(trainset, batch_size=run.config['batch_size'],
                          drop_last=False, num_workers=3, shuffle=True)
valid_loader = DataLoader(valset, batch_size=run.config['batch_size'],
                          drop_last=False, num_workers=3)
test_loader = DataLoader(testset, batch_size=run.config['batch_size'],
                         drop_last=False, num_workers=3)


device = run.config['device']
model = MultiModal(
    common_dim=run.config['common_dim'], n_classes=run.config['n_classes']).to(device)
# map_location restores the checkpoint tensors straight onto `device` instead of
# the GPU they were saved from (which may differ, or not exist on this machine).
# torch.load unpickles the file, so only load trusted checkpoints.
model.load_state_dict(torch.load("log/Missing_modal2/best_model59.pth", map_location=device))
criterion = nn.CrossEntropyLoss()

# Evaluate the video-only head on the TEST split ('vid' is neither 'multimodal'
# nor 'emg', so get_accuracy uses outputs[0]). Despite the variable names, these
# are test-set metrics, and F1/precision/recall are macro-averaged.
val_acc, f1_score_micro, precision_score_micro, recall_score_micro, predicted_labels, truth_labels = get_accuracy(
    model, test_loader, device, 'vid')
print(val_acc, f1_score_micro, precision_score_micro, recall_score_micro)
wandb.run.finish()

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /home/aiotlabws/.netrc


47353
13801
14188
0.8973987392217955 0.8964846746602742 0.9042700120883858 0.8964676041890346




In [7]:
import torch
from loader.dataloader import SpectrogramData
import glob

In [8]:
# Load the EMG training split through SpectrogramData. When run, the constructor
# printed 10464 (presumably the number of samples — TODO confirm).
data = SpectrogramData("data/new_data/emg_train.pkl")

10464


In [9]:
# Inspect the first sample. The displayed value is a 3-tuple: an int (0), a label
# tensor (tensor(40)) and a complex-valued tensor (presumably the spectrogram —
# TODO confirm against SpectrogramData.__getitem__).
data[0]

(0,
 tensor(40),
 tensor([[[-6.8987e-05+0.0000e+00j, -7.8284e-05+0.0000e+00j,
           -3.1138e-05+0.0000e+00j,  ...,
            1.2430e-04+0.0000e+00j, -9.0423e-05+0.0000e+00j,
           -7.7840e-05+0.0000e+00j],
          [ 4.7402e-05-3.8123e-05j,  2.7431e-05-1.1573e-05j,
           -5.3318e-06+5.9286e-05j,  ...,
           -4.3801e-04-4.0099e-04j, -5.9275e-05+2.5222e-04j,
            1.0619e-04+4.1414e-05j],
          [-1.6749e-05+4.1634e-05j,  7.0033e-06+7.8769e-06j,
            5.7369e-06-8.2540e-05j,  ...,
            5.0735e-04+8.0431e-04j,  3.5974e-05-1.3661e-04j,
           -3.3681e-05-1.4023e-04j],
          ...,
          [ 4.0287e-07+3.6197e-08j, -5.2112e-07+1.4716e-07j,
           -5.6393e-07-1.3571e-07j,  ...,
            3.9347e-07+6.7824e-07j,  8.8343e-07+4.6096e-07j,
           -1.5179e-06+4.4456e-07j],
          [-4.1944e-07+0.0000e+00j,  3.1974e-07+0.0000e+00j,
            2.5572e-07+0.0000e+00j,  ...,
           -4.9628e-07+0.0000e+00j, -1.3422e-06+0.0000e+00j,


In [7]:
from trainer import ViT
from VITforEMGAndBone import ViTForEMGAndBone, CrossAttention
from GCNforVIdeo import GGCN, find_adjacency_matrix
from torch import nn
import torch
from einops import rearrange, repeat
import math
from loader.dataloader import MultiModalData
from torch.utils.data import Dataset, DataLoader
from util.log import Log
from sklearn.metrics import precision_score, recall_score, f1_score
from tqdm.auto import tqdm
import random
import os
import numpy as np
import wandb
import torch.nn.functional as F

In [2]:
class CrossAttentionFor3Modalites(nn.Module):
    """Bidirectional cross-attention between two modality sequences.

    Holds two float64 ``CrossAttention`` layers, one per direction, and returns
    their outputs concatenated along the last axis.

    Args:
        dim1: first size argument of the modality1->modality2 layer (presumably
            the last-axis width of ``modality1`` — TODO confirm).
        dim2: first size argument of the modality2->modality1 layer.
        out_dim: output-size argument passed to both layers.
    """

    def __init__(self, dim1, dim2, out_dim):
        super().__init__()
        # Construction order (1 then 2) is kept so seeded initialisation is unchanged.
        self.crossAtt1 = CrossAttention(dim1, dim2, out_dim).double()
        self.crossAtt2 = CrossAttention(dim2, dim1, out_dim).double()

    def forward(self, modality1, modality2):
        """Attend each modality to the other; concatenate both results on the last axis."""
        return torch.cat(
            (
                self.crossAtt1(modality1, modality2),
                self.crossAtt2(modality2, modality1),
            ),
            dim=-1,
        )
    


In [5]:
# Shape sanity check for the cross-attention inputs, using random tensors
# (values are irrelevant; only the resulting shapes matter).
modality1 = rearrange(torch.rand((5, 5, 21, 3)), "a b c d -> a b (c d)").double()
modality2 = rearrange(torch.rand((5, 8820, 8)), "b (a d) c -> b a (d c)", a=5).double()
# One-step equivalent of "b c a d -> b d (c a)" followed by
# "b (a e) f -> b a (e f)" with a=5: the last axis is split into 5 chunks and
# each chunk is flattened together with the two middle axes.
modality3 = rearrange(torch.rand((5, 8, 130, 70)), "b c h (a e) -> b a (e c h)", a=5).double()

# Resulting feature widths: 63 (= 21*3), 14112 (= 8820/5 * 8) and 14560 (= 70/5 * 8 * 130).
print(modality3.shape)

torch.Size([5, 5, 14560])


In [9]:
import glob

# Report saved spectrogram files whose second dimension exceeds 130 (the size
# used in the shape check above — TODO confirm this is the intended limit).
MAX_SPECTROGRAM_DIM1 = 130

# Use a dedicated name so the SpectrogramData object `data` from the earlier
# cell is not clobbered; sort for a deterministic report order.
spectrogram_paths = sorted(glob.glob("data/new_data/new_spectrogram/*"))
for path in spectrogram_paths:
    # torch.load unpickles the file: only run this on trusted data.
    # map_location="cpu" avoids allocating each file on a GPU.
    spectrogram, _label = torch.load(path, map_location="cpu")
    if spectrogram.shape[1] > MAX_SPECTROGRAM_DIM1:
        print(path)

data/new_data/new_spectrogram/PhongTrang_20230701_02_P_12.pkl
