# Attention Fusion

In [1]:
from models import *
from utils import *

In [2]:
class AttentionFusion(nn.Module):
    """Attention step followed by a feed-forward step, each wrapped as
    ``norm(x + dropout(sublayer(x)))``.

    Args:
        in_ffn: input size of the feed-forward block; also the size of ``norm1``.
        out_ffn: output size of the feed-forward block; also the size of ``norm2``.
        drop_ffn: dropout rate handed to the feed-forward block.
        att_dim: hidden size of both attention modules; also the size of ``norm3``.
        drop_att: dropout rate handed to both attention modules.
        n_head: number of attention heads.

    NOTE(review): ``mhatt2``, ``dropout2`` and ``norm2`` are constructed but
    never used by ``forward``. They are kept so the set of registered
    parameters does not change.
    NOTE(review): ``norm1`` (sized ``in_ffn``) follows the attention output and
    ``norm3`` (sized ``att_dim``) follows the feed-forward output, so the block
    presumably only works when ``in_ffn == out_ffn == att_dim`` - confirm.
    """

    def __init__(self, in_ffn=64, out_ffn=64, drop_ffn=0.15, att_dim=64, drop_att=0.25, n_head=2):
        super().__init__()

        # Sub-modules are created in the original order so that parameter
        # initialisation draws from the RNG in the same sequence.
        attention_cfg = dict(hidden_size=att_dim, dropout=drop_att, n_head=n_head)
        self.mhatt1 = MultiHeadAttention(**attention_cfg)
        self.mhatt2 = MultiHeadAttention(**attention_cfg)
        self.ffn = FC(in_size=in_ffn, out_size=out_ffn, dropout_r=drop_ffn)

        residual_drop = 0.25  # dropout rate on every residual branch

        self.dropout1 = nn.Dropout(residual_drop)
        self.norm1 = LayerNorm(in_ffn)

        self.dropout2 = nn.Dropout(residual_drop)
        self.norm2 = LayerNorm(out_ffn)

        self.dropout3 = nn.Dropout(residual_drop)
        self.norm3 = LayerNorm(att_dim)

    def forward(self, x, y):
        """Fuse ``y`` into ``x`` and return the updated ``x``.

        ``x`` supplies the attention keys and values, ``y`` supplies the queries.
        """
        # Attention sub-layer: residual connection around mhatt1.
        attended = self.mhatt1(v=x, k=x, q=y)
        x = self.norm1(x + self.dropout1(attended))

        # Feed-forward sub-layer: residual connection around ffn.
        transformed = self.ffn(x)
        return self.norm3(x + self.dropout3(transformed))


class FusionNet(nn.Module):
    """Two-branch network: one encoder for EEG, one for the other physiological
    signals, merged by an ``AttentionFusion`` block.

    ``forward`` returns only the fused feature. The ``Classifier`` and
    ``Discriminator`` heads are exposed as attributes and are applied to that
    feature by the caller.

    Args:
        eeg_dim: accepted for interface compatibility; not used here.
        eeg_emb_dim: size of the fused feature consumed by both heads; half of
            it is passed to the EEG encoder.
        pos: accepted for interface compatibility; not used here.
        phy_dim: accepted for interface compatibility; not used here.
        phy_emb_dim: passed to the physiological encoder and used as the
            attention size of the fusion block.
        phy_feat_dim: feature size passed to the physiological encoder.
        eeg_feat_dim: feature size passed to the EEG encoder.
        n_class: number of output classes of ``Classifier``.

    ``Classifier`` returns raw, unnormalised class scores (logits). The
    previous trailing ``nn.Softmax`` was removed because the scores are fed to
    ``cross_entropy``, which applies log-softmax itself; normalising twice
    flattens the loss and its gradients. ``argmax`` over the scores is
    unaffected. Apply ``torch.softmax(scores, dim=1)`` if probabilities are
    needed.
    """

    def __init__(self, eeg_dim, eeg_emb_dim,pos, phy_dim, phy_emb_dim, phy_feat_dim, eeg_feat_dim=5, n_class=2):
        super().__init__()

        # Per-modality encoders (defined in the project's models module).
        self.eeg_attention = RegionRNN_DEAP(eeg_emb_dim//2, 1, eeg_feat_dim, f_dim=eeg_feat_dim)
        self.phy_attention = SimpleRNN_DEAP(phy_emb_dim, 1, phy_feat_dim, f_dim=phy_feat_dim)

        self.AttentionFusion = AttentionFusion(in_ffn=eeg_emb_dim, out_ffn=eeg_emb_dim, att_dim=phy_emb_dim)

        # Label head: emits logits, no final Softmax (see class docstring).
        # Layer indices of the two Linear modules are unchanged, so existing
        # state_dicts still load.
        self.Classifier = nn.Sequential(
            nn.Linear(eeg_emb_dim, 64),
            nn.Dropout(0.25),
            nn.ReLU(True),
            nn.Linear(64, n_class),
            )

        # Domain head: GradientReversal comes first so the gradient reaching
        # the shared feature passes through it before the encoders.
        self.Discriminator = nn.Sequential(
            GradientReversal(),
            nn.Linear(eeg_emb_dim, 64),
            nn.Dropout(0.25),
            nn.ReLU(True),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

    def forward(self, x_eeg, x_phy):
        """Encode both inputs and return the fused feature.

        Axes 1 and 2 of each input are swapped before it is handed to its
        encoder.
        """
        spatial_eeg = x_eeg.transpose(1,2)
        feat_eeg = self.eeg_attention(spatial_eeg)

        spatial_phy = x_phy.transpose(1,2)
        feat_phy = self.phy_attention(spatial_phy)

        feat = self.AttentionFusion(feat_eeg, feat_phy)

        return feat

In [3]:
'''Parameters'''
# Mini-batch size used for both DataLoaders built inside train_model.
batch_size = 64
# Number of training epochs train_model runs for each participant.
n_epoch = 1000

In [4]:
def _evaluate(net, loader):
    """Return ``(mean batch loss, mean batch accuracy)`` of ``net`` over ``loader``.

    The network is put in eval mode and autograd is disabled, so dropout is
    inactive and no graph is built. The previous train/eval mode is restored
    before returning.
    """
    was_training = net.training
    net.eval()
    total_loss = 0.0
    batch_acc = []
    with torch.no_grad():
        for inputs_eeg, inputs_phy, labels in loader:
            feat_ = net(inputs_eeg.to(torch.float32).cuda(), inputs_phy.to(torch.float32).cuda())
            outputs = net.Classifier(feat_)
            # Same long cast as in the training step.
            loss = torch.nn.functional.cross_entropy(outputs, labels.to(torch.long).cuda())
            total_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            batch_acc.append(torch.sum(predicted.cpu() == labels).item() / labels.shape[0])
    net.train(was_training)
    return total_loss / len(batch_acc), sum(batch_acc) / len(batch_acc)


def train_model(feat_eeg, feat_phy, label, participant, path_results, training_info, n_class, pos,
                domain_weight=0.0):
    """Train and validate one ``FusionNet`` per participant and save the curves.

    For every unique value in ``participant`` the matching samples are split
    80/20 at random into a train and a held-out set. A fresh network is then
    trained for ``n_epoch`` epochs with SGD. After each participant the
    accumulated results are written with ``np.save`` to
    ``os.path.join(path_results, training_info)``.

    Args:
        feat_eeg: EEG features; axis 1 is read as the feature size and axis 2
            as the channel count.
        feat_phy: physiological features, read the same way (axis 1 feature
            size, axis 2 signal count).
        label: per-sample class labels.
        participant: per-sample participant ids.
        path_results: directory for the result file; created if missing.
        training_info: file stem of the result file.
        n_class: number of classes.
        pos: forwarded to ``FusionNet`` unchanged.
        domain_weight: weight of the adversarial domain loss (train batch = 1,
            held-out batch = 0) added to the label loss. The default ``0.0``
            reproduces the original effective objective, in which the domain
            loss was computed and then overwritten by the label loss.

    Raises:
        ValueError: if a participant has too few samples for a non-empty
            train and held-out split.

    The saved object is a dict mapping ``'participant_<id>'`` to an array with
    one row per epoch:
    ``(train loss, train accuracy, validation loss, validation accuracy)``.

    Relies on the module-level ``batch_size`` and ``n_epoch``.
    """
    EEG = EEGPhiDataset(label=label, eeg=feat_eeg, phi=feat_phy)
    Tot = {}

    # np.save does not create missing directories.
    if path_results:
        os.makedirs(path_results, exist_ok=True)

    # Input sizes are the same for every participant, so read them once.
    n_chan = feat_eeg.shape[2]
    f_dim_eeg = feat_eeg.shape[1]
    phy_dim = feat_phy.shape[2]
    f_dim_phy = feat_phy.shape[1]

    for p in tqdm(np.unique(participant)):
        # flatnonzero always returns a 1-D index array, including when only a
        # single sample matches (argwhere(...).squeeze() collapsed to 0-d).
        idx = np.flatnonzero(participant == p)
        np.random.shuffle(idx)
        n_train = int(0.8 * len(idx))
        id_train = idx[:n_train]
        id_test = idx[n_train:]
        if len(id_train) == 0 or len(id_test) == 0:
            raise ValueError(
                f"participant {p!r} has {len(idx)} sample(s); "
                "need enough for a non-empty 80/20 train/test split"
            )

        # Redundant after the shuffle above, but kept so numpy's RNG is
        # consumed exactly as before and seeded runs produce the same splits.
        np.random.shuffle(id_train)
        np.random.shuffle(id_test)

        Test = Subset(EEG, id_test)
        Train = Subset(EEG, id_train)

        Trainloader = DataLoader(Train, batch_size=batch_size, shuffle=False)
        Testloader = DataLoader(Test, batch_size=batch_size, shuffle=False)

        net = FusionNet(eeg_dim=n_chan, eeg_emb_dim=64, eeg_feat_dim=f_dim_eeg, pos=pos, phy_dim=phy_dim, phy_emb_dim=64, phy_feat_dim=f_dim_phy, n_class=n_class).cuda()

        optimizer = optim.SGD(net.parameters(), lr=0.0005, momentum=0.9, weight_decay=0.005)

        res = []
        for _ in range(n_epoch):
            net.train()
            running_loss = 0.0
            evaluation = []
            # Held-out batches are recycled so that every train batch is
            # paired with one unlabeled held-out batch.
            t_cycle = iter(cycle(Testloader))
            for inputs_eeg_source, inputs_phy_source, labels in Trainloader:
                inputs_eeg_test, inputs_phy_test, _ = next(t_cycle)
                n_source = inputs_eeg_source.shape[0]

                # As in the original, the held-out inputs go through the
                # network together with the train inputs even when
                # domain_weight is 0.
                inputs_eeg = torch.cat([inputs_eeg_source, inputs_eeg_test])
                inputs_phy = torch.cat([inputs_phy_source, inputs_phy_test])

                optimizer.zero_grad()
                # One forward pass feeds both heads (the original ran two).
                feat_ = net(inputs_eeg.to(torch.float32).cuda(), inputs_phy.to(torch.float32).cuda())

                # Only the train part of the batch has labels.
                outputs = net.Classifier(feat_[:n_source])
                label_loss = torch.nn.functional.cross_entropy(outputs, labels.to(torch.long).cuda())
                loss = label_loss

                if domain_weight:
                    domain = torch.cat([torch.ones(n_source),
                                        torch.zeros(inputs_eeg_test.shape[0])]).cuda()
                    domain_pred = net.Discriminator(feat_).squeeze()
                    domain_loss = torch.nn.functional.binary_cross_entropy(domain_pred, domain)
                    loss = loss + domain_weight * domain_loss

                loss.backward()
                optimizer.step()

                _, predicted = torch.max(outputs, 1)
                evaluation.append(torch.sum(predicted.detach().cpu() == labels).item() / labels.shape[0])
                # Accumulate: the original assigned here, so the reported
                # value was the last batch's loss divided by the batch count.
                running_loss += label_loss.item()

            running_loss = running_loss / len(evaluation)
            running_acc = sum(evaluation) / len(evaluation)

            validation_loss, validation_acc = _evaluate(net, Testloader)

            res.append((running_loss, running_acc, validation_loss, validation_acc))
        Tot['participant_'+str(p)] = np.asarray(res)
        # Saved after every participant so finished results survive an
        # interrupted run.
        np.save(os.path.join(path_results, training_info), Tot)

## PhyDAA Dataset

In [None]:
'''Load File'''
# Per-sample class labels, cast to int; the class count is the number of
# distinct label values.
label = np.load('dataset/phydaa/label.npy').astype(int)
n_class = len(np.unique(label))
# Per-sample participant ids; train_model fits one model per unique id.
participant = np.load('dataset/phydaa/participant.npy')
# Presumably the EEG electrode positions - passed to train_model as `pos`.
elec_pos = np.load('information/phydaa_eeg.npy')

In [None]:
# Run per-participant training on PhyDAA. feat_phy gets a length-1 axis
# inserted at position 1 before it is passed on. Results are saved under
# 'res/ind/' with the file stem 'phydaa_attention'.
train_model(feat_eeg=np.load('dataset/phydaa/feat_eeg.npy'), feat_phy=np.expand_dims(np.load('dataset/phydaa/feat_phy.npy'),1), label=label, 
            participant=participant, path_results='res/ind/', training_info='phydaa_attention', n_class=n_class, pos=elec_pos)

## SEED IV

In [None]:
'''Load File'''
# Per-sample class labels, cast to int; the class count is the number of
# distinct label values.
# NOTE(review): the path contains a doubled slash ('seed_iv//label.npy'). It
# resolves to the same file on common platforms but looks like a typo.
label = np.load('dataset/seed_iv//label.npy').astype(int)
n_class = len(np.unique(label))
# Per-sample participant ids; train_model fits one model per unique id.
participant = np.load('dataset/seed_iv/participant.npy')
# Presumably the EEG electrode positions - passed to train_model as `pos`.
elec_pos = np.load('information/seed_iv_eeg.npy')

In [None]:
phy_feat = np.load('dataset/seed_iv/feat_phy.npy')

# Min-max rescale each contiguous group of columns with that group's own
# global minimum and range (one min / max per group, not per column).
for first, last in ((0, 12), (12, 16), (16, 18), (18, 22)):
    shifted = phy_feat[:, first:last] - phy_feat[:, first:last].min()
    phy_feat[:, first:last] = shifted / np.max(shifted)

# Only the first 22 columns (the rescaled ones) are passed on, after a
# length-1 axis has been inserted at position 1.
train_model(
    feat_eeg=np.load('dataset/seed_iv/feat_eeg.npy'),
    feat_phy=np.expand_dims(phy_feat, 1)[:, :, :22],
    label=label,
    participant=participant,
    path_results='res/dep/',
    training_info='seed_iv_attention_rnn',
    n_class=n_class,
    pos=elec_pos,
)

## DEAP

In [5]:
'''Load File'''
# Centre the ratings at 4.5 so that the sign separates "low" from "high".
# NOTE(review): presumably columns 0 and 1 are valence and arousal on a 1-9
# scale - confirm against the dataset description.
label = np.load('dataset/deap/label.npy') - 4.5
# Combine the two binary signs into one class id in {0, 1, 2, 3}:
# bit 0 = column 0 above 4.5, bit 1 = column 1 above 4.5.
label = (label[:, 0]>0).astype(int) + 2*(label[:, 1]>0).astype(int)
n_class = len(np.unique(label))
# Per-sample participant ids; train_model fits one model per unique id.
participant = np.load('dataset/deap/participant.npy')

In [6]:
# Run per-participant training on DEAP. Both feature arrays are passed as
# loaded, and no electrode positions are supplied (pos=None). Results are
# saved under 'res/dep/' with the file stem 'deap_attention_rnn'.
train_model(feat_eeg=np.load('dataset/deap/feat_eeg.npy'), feat_phy=np.load('dataset/deap/feat_phy.npy'), label=label, 
            participant=participant, path_results='res/dep/', training_info='deap_attention_rnn', n_class=n_class, pos=None)

100%|██████████| 32/32 [2:51:23<00:00, 321.37s/it]  


In [7]:
import IPython

# Ask the running kernel to shut down and restart (the argument True requests
# a restart) - presumably to release the GPU memory held after training.
IPython.Application.instance().kernel.do_shutdown(True)

{'status': 'ok', 'restart': True}