# Naïve fusion

In [None]:
from models import *
from utils  import *

In [None]:
class MultiFusion(nn.Module):
    """Naive multiplicative fusion of an EEG branch and a physiological branch.

    Each modality is encoded by its own recurrent module and the two
    embeddings are combined with an element-wise product. ``forward`` returns
    only the fused feature; the ``Classifier``, ``Regressor`` and
    ``Discriminator`` heads are applied by the caller.

    Args:
        eeg_dim: Not used by this model; kept so all fusion models share one
            constructor signature.
        eeg_emb_dim: Input width of the three heads. The EEG encoder receives
            ``eeg_emb_dim // 2`` as its first argument (presumably a
            per-direction hidden size of a bidirectional RNN, so its output is
            ``eeg_emb_dim`` wide -- TODO confirm against ``RegionRNN_VIG``).
        pos: Not used by this model (kept for signature parity).
        phy_dim: Not used by this model (kept for signature parity).
        phy_emb_dim: First argument of the physiological encoder. Because the
            fusion is an element-wise product, the physiological embedding
            must be broadcast-compatible with the EEG embedding.
        phy_feat_dim: Feature size forwarded to the physiological encoder.
        eeg_feat_dim: Feature size forwarded to the EEG encoder.
        n_class: Output size of the ``Classifier`` head.
    """

    def __init__(self, eeg_dim, eeg_emb_dim, pos, phy_dim, phy_emb_dim, phy_feat_dim, eeg_feat_dim=5, n_class=2):
        super().__init__()

        # Submodules are created in the original order so seeded weight
        # initialisation and parameter ordering stay the same.
        self.eeg_attention = RegionRNN_VIG(eeg_emb_dim // 2, 1, eeg_feat_dim, f_dim=eeg_feat_dim)
        self.phy_attention = SimpleRNN(phy_emb_dim, 1, phy_feat_dim, f_dim=phy_feat_dim)

        self.Classifier = self._head(eeg_emb_dim, n_class, post=nn.Softmax(dim=1))
        self.Regressor = self._head(eeg_emb_dim, 1)
        # Gradient reversal in front of the domain head (presumably DANN-style
        # adversarial domain adaptation, given how train_model uses it).
        self.Discriminator = self._head(eeg_emb_dim, 1, pre=GradientReversal(), post=nn.Sigmoid())

    @staticmethod
    def _head(in_dim, out_dim, pre=None, post=None):
        """Build ``[pre ->] Linear(in_dim, 64) -> Dropout(0.25) -> ReLU -> Linear(64, out_dim) [-> post]``.

        ``pre``/``post`` sit at the ends so the ``nn.Sequential`` indices (and
        hence ``state_dict`` keys) match the previously hand-written heads.
        """
        layers = [] if pre is None else [pre]
        layers += [nn.Linear(in_dim, 64), nn.Dropout(0.25), nn.ReLU(True), nn.Linear(64, out_dim)]
        if post is not None:
            layers.append(post)
        return nn.Sequential(*layers)

    def forward(self, x_eeg, x_phy):
        """Encode both modalities and return the element-wise product of their embeddings.

        Axes 1 and 2 of each input are swapped before encoding. Assumes inputs
        are at least 3-D, e.g. (batch, features, channels) as laid out by
        ``train_model`` -- TODO confirm against the encoders.
        """
        feat_eeg = self.eeg_attention(x_eeg.transpose(1, 2))
        feat_phy = self.phy_attention(x_phy.transpose(1, 2))
        return feat_eeg * feat_phy

class CatFusion(nn.Module):
    """Naive concatenation fusion of an EEG branch and a physiological branch.

    Each modality is encoded by its own recurrent module and the two
    embeddings are concatenated along dimension 1. ``forward`` returns only the
    fused feature; the ``Classifier``, ``Regressor`` and ``Discriminator``
    heads (all sized ``eeg_emb_dim + phy_emb_dim``) are applied by the caller.

    Args:
        eeg_dim: Not used by this model; kept so all fusion models share one
            constructor signature.
        eeg_emb_dim: EEG embedding width as seen by the heads. The EEG encoder
            receives ``eeg_emb_dim // 2`` as its first argument (presumably a
            per-direction hidden size of a bidirectional RNN -- TODO confirm
            against ``RegionRNN_VIG``).
        pos: Not used by this model (kept for signature parity).
        phy_dim: Not used by this model (kept for signature parity).
        phy_emb_dim: Physiological embedding width as seen by the heads.
        phy_feat_dim: Feature size forwarded to the physiological encoder.
        eeg_feat_dim: Feature size forwarded to the EEG encoder.
        n_class: Output size of the ``Classifier`` head.
    """

    def __init__(self, eeg_dim, eeg_emb_dim, pos, phy_dim, phy_emb_dim, phy_feat_dim, eeg_feat_dim=5, n_class=2):
        super().__init__()

        # Submodules are created in the original order so seeded weight
        # initialisation and parameter ordering stay the same.
        self.eeg_attention = RegionRNN_VIG(eeg_emb_dim // 2, 1, eeg_feat_dim, f_dim=eeg_feat_dim)
        self.phy_attention = SimpleRNN(phy_emb_dim, 1, phy_feat_dim, f_dim=phy_feat_dim)

        fused_dim = eeg_emb_dim + phy_emb_dim
        self.Classifier = self._head(fused_dim, n_class, post=nn.Softmax(dim=1))
        self.Regressor = self._head(fused_dim, 1)
        # Gradient reversal in front of the domain head (presumably DANN-style
        # adversarial domain adaptation, given how train_model uses it).
        self.Discriminator = self._head(fused_dim, 1, pre=GradientReversal(), post=nn.Sigmoid())

    @staticmethod
    def _head(in_dim, out_dim, pre=None, post=None):
        """Build ``[pre ->] Linear(in_dim, 64) -> Dropout(0.25) -> ReLU -> Linear(64, out_dim) [-> post]``.

        ``pre``/``post`` sit at the ends so the ``nn.Sequential`` indices (and
        hence ``state_dict`` keys) match the previously hand-written heads.
        """
        layers = [] if pre is None else [pre]
        layers += [nn.Linear(in_dim, 64), nn.Dropout(0.25), nn.ReLU(True), nn.Linear(64, out_dim)]
        if post is not None:
            layers.append(post)
        return nn.Sequential(*layers)

    def forward(self, x_eeg, x_phy):
        """Encode both modalities and concatenate their embeddings along dim 1.

        Axes 1 and 2 of each input are swapped before encoding. Assumes inputs
        are at least 3-D, e.g. (batch, features, channels) as laid out by
        ``train_model`` -- TODO confirm against the encoders.
        """
        feat_eeg = self.eeg_attention(x_eeg.transpose(1, 2))
        feat_phy = self.phy_attention(x_phy.transpose(1, 2))
        return torch.cat([feat_eeg, feat_phy], dim=1)

class BilFusion(nn.Module):
    """Fusion of an EEG branch and a physiological branch via compact bilinear pooling.

    Each modality is encoded by its own recurrent module and the two
    embeddings are combined by ``CompactBilinearPooling`` into a feature of
    width ``eeg_emb_dim + phy_emb_dim``. ``forward`` returns only the fused
    feature; the ``Classifier``, ``Regressor`` and ``Discriminator`` heads are
    applied by the caller.

    Args:
        eeg_dim: Not used by this model; kept so all fusion models share one
            constructor signature.
        eeg_emb_dim: EEG embedding width given to the pooling layer. The EEG
            encoder receives ``eeg_emb_dim // 2`` as its first argument
            (presumably a per-direction hidden size of a bidirectional RNN --
            TODO confirm against ``RegionRNN_VIG``).
        pos: Not used by this model (kept for signature parity).
        phy_dim: Not used by this model (kept for signature parity).
        phy_emb_dim: Physiological embedding width given to the pooling layer.
        phy_feat_dim: Feature size forwarded to the physiological encoder.
        eeg_feat_dim: Feature size forwarded to the EEG encoder.
        n_class: Output size of the ``Classifier`` head.
    """

    def __init__(self, eeg_dim, eeg_emb_dim, pos, phy_dim, phy_emb_dim, phy_feat_dim, eeg_feat_dim=5, n_class=2):
        super().__init__()

        # Submodules are created in the original order so seeded weight
        # initialisation and parameter ordering stay the same.
        self.eeg_attention = RegionRNN_VIG(eeg_emb_dim // 2, 1, eeg_feat_dim, f_dim=eeg_feat_dim)
        self.phy_attention = SimpleRNN(phy_emb_dim, 1, phy_feat_dim, f_dim=phy_feat_dim)

        fused_dim = eeg_emb_dim + phy_emb_dim
        mcb = CompactBilinearPooling(eeg_emb_dim, phy_emb_dim, fused_dim)
        # Moved to the GPU eagerly when one exists, as before. NOTE(review):
        # presumably because the pooling layer holds tensors that a later
        # net.cuda() would not move -- confirm against CompactBilinearPooling.
        if torch.cuda.is_available():
            mcb = mcb.cuda()
        self.mcb = mcb

        self.Classifier = self._head(fused_dim, n_class, post=nn.Softmax(dim=1))
        self.Regressor = self._head(fused_dim, 1)
        # Gradient reversal in front of the domain head (presumably DANN-style
        # adversarial domain adaptation, given how train_model uses it).
        self.Discriminator = self._head(fused_dim, 1, pre=GradientReversal(), post=nn.Sigmoid())

    @staticmethod
    def _head(in_dim, out_dim, pre=None, post=None):
        """Build ``[pre ->] Linear(in_dim, 64) -> Dropout(0.25) -> ReLU -> Linear(64, out_dim) [-> post]``.

        ``pre``/``post`` sit at the ends so the ``nn.Sequential`` indices (and
        hence ``state_dict`` keys) match the previously hand-written heads.
        """
        layers = [] if pre is None else [pre]
        layers += [nn.Linear(in_dim, 64), nn.Dropout(0.25), nn.ReLU(True), nn.Linear(64, out_dim)]
        if post is not None:
            layers.append(post)
        return nn.Sequential(*layers)

    def forward(self, x_eeg, x_phy):
        """Encode both modalities and fuse them with compact bilinear pooling.

        Axes 1 and 2 of each input are swapped before encoding. Assumes inputs
        are at least 3-D, e.g. (batch, features, channels) as laid out by
        ``train_model`` -- TODO confirm against the encoders.
        """
        feat_eeg = self.eeg_attention(x_eeg.transpose(1, 2))
        feat_phy = self.phy_attention(x_phy.transpose(1, 2))
        return self.mcb(feat_eeg, feat_phy)

In [None]:
# Training hyper-parameters, read as globals by train_model.
batch_size = 64  # mini-batch size of both the train and the test DataLoader
n_epoch = 500    # training epochs per participant

In [None]:
def _fusion_class(fusion):
    """Map a fusion name ('cat', 'mult', 'bil') to its network class; raise ValueError otherwise."""
    classes = {'cat': CatFusion, 'mult': MultiFusion, 'bil': BilFusion}
    try:
        return classes[fusion]
    except KeyError:
        raise ValueError(f"unknown fusion {fusion!r}; expected one of {sorted(classes)}") from None


def _train_one_epoch(net, optimizer, train_loader, target_loader, device):
    """Run one epoch of source regression plus adversarial domain training.

    Each labelled source batch is paired with the next (cycled) batch of
    ``target_loader``, whose labels are ignored. The Discriminator is fit with
    BCE to tell source (1) from target (0) features; the Regressor is fit with
    MSE on the source part only. The two losses are summed and back-propagated.
    """
    target_batches = cycle(target_loader)
    for eeg_src, phy_src, labels in train_loader:
        eeg_tgt, phy_tgt, _ = next(target_batches)
        n_src = eeg_src.shape[0]

        domain = torch.cat([torch.ones(n_src), torch.zeros(eeg_tgt.shape[0])]).to(device)
        inputs_eeg = torch.cat([eeg_src, eeg_tgt]).to(device, torch.float32)
        inputs_phy = torch.cat([phy_src, phy_tgt]).to(device, torch.float32)

        optimizer.zero_grad()
        feat = net(inputs_eeg, inputs_phy)
        # reshape(-1) instead of squeeze(): stays 1-D even for a single-sample batch.
        domain_pred = net.Discriminator(feat).reshape(-1)
        domain_loss = torch.nn.functional.binary_cross_entropy(domain_pred, domain)
        outputs = net.Regressor(feat[:n_src]).reshape(-1)
        label_loss = torch.nn.functional.mse_loss(outputs, labels.to(device, torch.float).reshape(-1))
        (domain_loss + label_loss).backward()
        optimizer.step()


def _predict(net, loader, device):
    """Return ``(predictions, labels)`` lists of ``net.Regressor`` over ``loader``.

    Runs in eval mode (so the heads' Dropout is disabled) and without building
    an autograd graph; the network is put back into training mode afterwards.
    """
    y_pred, y_label = [], []
    net.eval()
    try:
        with torch.no_grad():
            for inputs_eeg, inputs_phy, labels in loader:
                feat = net(inputs_eeg.to(device, torch.float32), inputs_phy.to(device, torch.float32))
                # reshape(-1) keeps a 1-D tensor for a 1-sample batch, so tolist() is always a list.
                y_pred.extend(net.Regressor(feat).reshape(-1).cpu().tolist())
                y_label.extend(labels.tolist())
    finally:
        net.train()
    return y_pred, y_label


def train_model(feat_eeg, feat_phy, label, participant, path_results, training_info, n_class, pos, fusion):
    """Train and evaluate one fusion network per participant (within-subject split).

    For every participant, that participant's sample indices are shuffled and
    split 80/20 into train/test. A fresh network is trained for ``n_epoch``
    epochs with SGD; the test inputs (without labels) serve as the target
    domain of the adversarial Discriminator. Every 10th epoch the Regressor's
    test-set predictions are recorded.

    Results accumulate in a dict keyed ``'participant_<p>'``. Each value is
    ``np.asarray`` of ``{'pred': array of shape (n_evaluations, n_test),
    'test': array of test labels}``. The whole dict is re-saved with
    ``np.save`` to ``os.path.join(path_results, training_info)`` after each
    participant.

    Args:
        feat_eeg: EEG features; ``shape[1]`` is passed as ``eeg_feat_dim`` and
            ``shape[2]`` as ``eeg_dim``.
        feat_phy: Physiological features; ``shape[1]`` is passed as
            ``phy_feat_dim`` and ``shape[2]`` as ``phy_dim``.
        label: Regression targets, given to ``EEGPhiDataset``.
        participant: Per-sample participant id; one model per unique value.
        path_results: Output directory for the saved results.
        training_info: Output file name (``np.save`` adds ``.npy``).
        n_class: Forwarded to the network's classification head (not used by
            this training loop).
        pos: Forwarded to the network as ``pos``.
        fusion: One of ``'cat'``, ``'mult'`` or ``'bil'``.

    Raises:
        ValueError: If ``fusion`` is not a known fusion name.
    """
    net_cls = _fusion_class(fusion)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = EEGPhiDataset(label=label, eeg=feat_eeg, phi=feat_phy)

    # Network dimensions depend only on the feature arrays, not on the participant.
    net_kwargs = dict(
        eeg_dim=feat_eeg.shape[2], eeg_emb_dim=64, eeg_feat_dim=feat_eeg.shape[1], pos=pos,
        phy_dim=feat_phy.shape[2], phy_emb_dim=64, phy_feat_dim=feat_phy.shape[1], n_class=n_class,
    )

    results = {}
    for p in tqdm(np.unique(participant)):
        # flatnonzero is always 1-D; argwhere(...).squeeze() is 0-d for a single match.
        idx = np.flatnonzero(participant == p)
        np.random.shuffle(idx)
        n_train = int(0.8 * len(idx))
        train_loader = DataLoader(Subset(dataset, idx[:n_train]), batch_size=batch_size, shuffle=False)
        test_loader = DataLoader(Subset(dataset, idx[n_train:]), batch_size=batch_size, shuffle=False)

        net = net_cls(**net_kwargs).to(device)
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001)

        pred = []
        y_label = []  # defined up front so results can be saved even if no evaluation ran
        for epoch in range(n_epoch):
            _train_one_epoch(net, optimizer, train_loader, test_loader, device)
            if epoch % 10 == 9:
                y_pred, y_label = _predict(net, test_loader, device)
                pred.append(y_pred)

        res = {'pred': np.asarray(pred), 'test': np.asarray(y_label)}
        results['participant_' + str(p)] = np.asarray(res)
        np.save(os.path.join(path_results, training_info), results)

## SEED VIG

In [None]:
# Load the SEED-VIG arrays shared by the experiment cells below.
label = np.load('dataset/seed_vig/label.npy')
# Number of distinct label values, forwarded to the models' classification head.
# NOTE(review): if these labels are continuous, this counts unique values rather
# than classes; train_model only trains the regression head, so it is unused there.
n_class = np.unique(label).size
participant = np.load('dataset/seed_vig/participant.npy')  # per-sample participant ids
elec_pos = np.load('information/seed_vig_eeg.npy')  # passed to the models as `pos`

### Multiplication

In [None]:
# feat_phy must come from SEED-VIG like label/participant/feat_eeg: train_model
# presumably indexes all of them with the same sample indices (this previously
# loaded dataset/seed_iv/feat_phy.npy by mistake).
train_model(feat_eeg=np.load('dataset/seed_vig/feat_eeg.npy'), feat_phy=np.expand_dims(np.load('dataset/seed_vig/feat_phy.npy'),1), label=label,
            participant=participant, path_results='res/', training_info='seed_vig_mult', n_class=n_class, pos=elec_pos, fusion='mult')

### Concatenation

In [None]:
# feat_phy must come from SEED-VIG like label/participant/feat_eeg: train_model
# presumably indexes all of them with the same sample indices (this previously
# loaded dataset/seed_iv/feat_phy.npy by mistake).
train_model(feat_eeg=np.load('dataset/seed_vig/feat_eeg.npy'), feat_phy=np.expand_dims(np.load('dataset/seed_vig/feat_phy.npy'),1), label=label,
            participant=participant, path_results='res/', training_info='seed_vig_cat', n_class=n_class, pos=elec_pos, fusion='cat')

### Compact Bilinear Pooling

In [None]:
# feat_phy must come from SEED-VIG like label/participant/feat_eeg: train_model
# presumably indexes all of them with the same sample indices (this previously
# loaded dataset/seed_iv/feat_phy.npy by mistake).
train_model(feat_eeg=np.load('dataset/seed_vig/feat_eeg.npy'), feat_phy=np.expand_dims(np.load('dataset/seed_vig/feat_phy.npy'),1), label=label,
            participant=participant, path_results='res/', training_info='seed_vig_bil', n_class=n_class, pos=elec_pos, fusion='bil')