# Single-Modality Regression

In [None]:
from models import *
from utils  import *

In [None]:
class MultiAttention(nn.Module):
    """Shared encoder with classification, regression and domain heads.

    ``forward`` returns only the encoder embedding. The heads (``Classifier``,
    ``Regressor``, ``Discriminator``) are separate sub-modules that the caller
    applies to that embedding.

    Args:
        spatial_dep: Accepted for interface compatibility; not used here.
        emb_dim: Width of the embedding consumed by every head
            (each head starts with ``nn.Linear(emb_dim, 64)``).
        feat_dim: Per-channel feature size forwarded to the encoder.
        eeg: If True the encoder is ``RegionRNN_VIG``, otherwise ``SimpleRNN``.
        pos: Accepted for interface compatibility; not used here.
        n_class: Number of outputs of the ``Classifier`` head.
    """

    def __init__(self, spatial_dep, emb_dim, feat_dim=5, eeg=False, pos=None, n_class=2):
        super().__init__()

        # Encoder choice. NOTE(review): RegionRNN_VIG receives emb_dim // 2,
        # presumably because it concatenates two halves (e.g. bidirectional)
        # to produce an emb_dim-wide output -- confirm against models.py.
        if eeg:
            self.spatial_attention = RegionRNN_VIG(emb_dim // 2, 1, feat_dim, f_dim=feat_dim)
        else:
            self.spatial_attention = SimpleRNN(emb_dim, 1, feat_dim, f_dim=feat_dim)

        # Classification head: per-class probabilities (Softmax over dim 1).
        self.Classifier = nn.Sequential(
            nn.Linear(emb_dim, 64),
            nn.Dropout(0.25),
            nn.ReLU(True),
            nn.Linear(64, n_class),
            nn.Softmax(dim=1),
        )

        # Regression head: one unbounded scalar per sample.
        self.Regressor = nn.Sequential(
            nn.Linear(emb_dim, 64),
            nn.Dropout(0.25),
            nn.ReLU(True),
            nn.Linear(64, 1),
        )

        # Domain head: GradientReversal, then an MLP ending in a sigmoid
        # probability. Presumably used for adversarial domain adaptation
        # (GradientReversal semantics live in models.py -- confirm).
        self.Discriminator = nn.Sequential(
            GradientReversal(),
            nn.Linear(emb_dim, 64),
            nn.Dropout(0.25),
            nn.ReLU(True),
            nn.Linear(64, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return the encoder embedding of ``x``.

        Axes 1 and 2 of ``x`` are swapped before it reaches the encoder.
        NOTE(review): the notebook's callers appear to pass
        (batch, feat_dim, n_chan), so the encoder would see
        (batch, n_chan, feat_dim) -- confirm.
        """
        return self.spatial_attention(x.transpose(1, 2))

In [None]:
# Training hyper-parameters read by train_model() below.
batch_size = 64  # mini-batch size for both the train and test DataLoaders
n_epoch = 500  # number of epochs trained per participant

In [None]:
def _split_participant(participant, p, train_frac=0.8):
    """Shuffle participant ``p``'s sample indices and split them train/test."""
    # flatnonzero always yields a 1-D index array (argwhere(...).squeeze()
    # collapses to 0-d when the participant has a single sample).
    idx = np.flatnonzero(participant == p)
    np.random.shuffle(idx)
    n_train = int(train_frac * len(idx))
    return idx[:n_train], idx[n_train:]


def _evaluate(net, loader):
    """Predict with ``net.Regressor`` over ``loader`` in eval mode.

    Returns:
        (y_pred, y_label): flat Python lists, one entry per sample.
    """
    was_training = net.training
    net.eval()  # disable dropout so predictions are deterministic
    y_pred, y_label = [], []
    with torch.no_grad():  # no autograd graph is needed for evaluation
        for inputs, labels in loader:
            outputs = net.Regressor(net(inputs.to(torch.float32).cuda()))
            # reshape(-1) (not squeeze) keeps a 1-D result for a batch of size 1.
            y_pred.extend(outputs.reshape(-1).cpu().tolist())
            y_label.extend(labels.reshape(-1).tolist())
    net.train(was_training)
    return y_pred, y_label


def train_model(feat, label, participant, path_results, training_info, n_class, is_eeg=False, pos=None):
    """Train one domain-adversarial regressor per participant and save its predictions.

    For each participant, that participant's samples are shuffled and split
    80/20 into train/test. A fresh ``MultiAttention`` net is trained with SGD on:

    * an MSE loss between ``Regressor`` outputs and labels (train batches);
    * a BCE loss on ``Discriminator`` outputs, with domain 1 for train batches
      and domain 0 for unlabelled test batches.

    Every 10th epoch the regressor is evaluated on the test split in eval mode.

    Args:
        feat: Feature array; ``feat.shape[1]`` is the feature dimension and
            ``feat.shape[2]`` the channel count passed to ``MultiAttention``.
        label: Regression target per sample.
        participant: Participant id per sample.
        path_results: Output directory.
        training_info: Output file name given to ``np.save``.
        n_class: Forwarded to ``MultiAttention``.
        is_eeg: Forwarded to ``MultiAttention`` as ``eeg``.
        pos: Forwarded to ``MultiAttention``.

    Side effects:
        After every participant, ``np.save`` rewrites
        ``os.path.join(path_results, training_info)`` with a dict keyed
        ``'participant_<p>'``. Each value is a 0-d object array wrapping
        ``{'pred': (n_evals, n_test) array, 'test': (n_test,) array}``.
    """
    EEG = EEGDataset(label=label, eeg=feat)
    f_dim, n_chan = feat.shape[1], feat.shape[2]
    Tot = {}
    for p in tqdm(np.unique(participant)):
        id_train, id_test = _split_participant(participant, p)
        Trainloader = DataLoader(Subset(EEG, id_train), batch_size=batch_size, shuffle=False)
        Testloader = DataLoader(Subset(EEG, id_test), batch_size=batch_size, shuffle=False)

        net = MultiAttention(spatial_dep=n_chan, emb_dim=64, feat_dim=f_dim, eeg=is_eeg, pos=pos, n_class=n_class).cuda()
        optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=0.0001)

        pred = []
        y_label = []  # bound up front so short runs (n_epoch < 10) don't NameError
        for epoch in range(n_epoch):
            # Unlabelled test batches are drawn alongside every train batch.
            t_cycle = iter(cycle(Testloader))
            for inputs_source, labels in Trainloader:
                inputs_test, _ = next(t_cycle)
                n_src = inputs_source.shape[0]

                domain = torch.cat([torch.ones(n_src),
                                    torch.zeros(inputs_test.shape[0])]).cuda()
                inputs = torch.cat([inputs_source, inputs_test])

                optimizer.zero_grad()
                feat_ = net(inputs.to(torch.float32).cuda())

                domain_pred = net.Discriminator(feat_).reshape(-1)
                domain_loss = torch.nn.functional.binary_cross_entropy(domain_pred, domain)

                # Only the labelled (source) rows feed the regression loss.
                outputs = net.Regressor(feat_[:n_src]).reshape(-1)
                label_loss = torch.nn.functional.mse_loss(
                    outputs, labels.to(torch.float).cuda().reshape(-1))

                (domain_loss + label_loss).backward()
                optimizer.step()

            if epoch % 10 == 9:
                y_pred, y_label = _evaluate(net, Testloader)
                pred.append(y_pred)

        res = {'pred': np.asarray(pred), 'test': np.asarray(y_label)}
        Tot['participant_' + str(p)] = np.asarray(res)
        # Saved inside the loop so completed participants survive an interruption.
        np.save(os.path.join(path_results, training_info), Tot)

## SEED-VIG

In [None]:
# Load SEED-VIG labels, per-sample participant ids and electrode positions.
participant = np.load('dataset/seed_vig/participant.npy')
label = np.load('dataset/seed_vig/label.npy')
# Number of distinct label values; it only reaches train_model's n_class argument.
n_class = np.unique(label).size
elec_pos = np.load('information/seed_vig_eeg.npy')

### EEG

In [None]:
# EEG features with electrode positions; np.save writes res/eeg_seed_vig.npy.
train_model(
    feat=np.load('dataset/seed_vig/feat_eeg.npy'),
    label=label,
    participant=participant,
    path_results='res/',
    training_info='eeg_seed_vig',
    n_class=n_class,
    is_eeg=True,
    pos=elec_pos,
)

### Physiological

In [None]:
# Physiological features. expand_dims inserts a length-1 axis at position 1,
# so train_model sees feat.shape[1] == 1 as the feature dimension.
train_model(
    feat=np.expand_dims(np.load('dataset/seed_vig/feat_phy.npy'), 1),
    label=label,
    participant=participant,
    path_results='res/',
    training_info='feat_seed_vig',
    n_class=n_class,
    is_eeg=False,
    pos=None,
)

In [None]:
import IPython

# Ask the running IPython kernel to shut down; the True argument requests a
# restart (presumably to release resources held by the session).
_kernel = IPython.Application.instance().kernel
_kernel.do_shutdown(True)