<a href="https://colab.research.google.com/github/Guo-Weiqiang/Master-Project/blob/main/baseline(EEGNet).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


# Loss Part




In [None]:
def MMD(source, target):
    """Linear-kernel MMD: squared Euclidean distance between the sample means.

    :param source: torch tensor of source features, one sample per row (Ns, D).
    :param target: torch tensor of target features, one sample per row (Nt, D).
    :return: scalar tensor ``||mean(source) - mean(target)||^2``.

    For Ns == Nt this equals ``mean((source - target) @ (source - target).T)``,
    because the mean over all entries of that Gram matrix is the squared norm
    of the mean difference. Working on the means avoids the (N, N)
    intermediate and also allows the two domains to have different sizes.
    """
    delta = source.mean(dim=0) - target.mean(dim=0)
    return (delta * delta).sum()


# def guassian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
#     n_samples = int(source.size()[0])+int(target.size()[0])
#     total = torch.cat([source, target], dim=0)
#     total0 = total.unsqueeze(0).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
#     total1 = total.unsqueeze(1).expand(int(total.size(0)), int(total.size(0)), int(total.size(1)))
#     L2_distance = ((total0-total1)**2).sum(2)
#     if fix_sigma:
#         bandwidth = fix_sigma
#     else:
#         bandwidth = torch.sum(L2_distance.data) / (n_samples**2-n_samples)
#     bandwidth /= kernel_mul ** (kernel_num // 2)
#     bandwidth_list = [bandwidth * (kernel_mul**i) for i in range(kernel_num)]
#     kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]
#     print(sum(kernel_val))
#     return sum(kernel_val)#/len(kernel_val)


# def MK_MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
#     batch_size = int(source.size()[0])
#     kernels = guassian_kernel(source, target,
#         kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

#     loss1 = 0
#     for s1 in range(batch_size):
#         for s2 in range(s1+1, batch_size):
#             t1, t2 = s1+batch_size, s2+batch_size
#             loss1 += kernels[s1, s2] + kernels[t1, t2]
#     loss1 = loss1 / float(batch_size * (batch_size - 1) / 2)

#     loss2 = 0
#     for s1 in range(batch_size):
#         for s2 in range(batch_size):
#             t1, t2 = s1+batch_size, s2+batch_size
#             loss2 -= kernels[s1, t2] + kernels[s2, t1]
#     loss2 = loss2 / float(batch_size * batch_size)
#     return loss1 + loss2

def gaussian_kernel(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Sum of Gaussian (RBF) kernels with several bandwidths over source ∪ target.

    The rows of ``source`` and ``target`` are stacked into one set of
    ``n = n_s + n_t`` points. For every pair the function returns
    ``sum_k exp(-||x_i - x_j||^2 / b_k)``. The top-left ``n_s x n_s`` block of
    the result is source/source, the bottom-right block is target/target, and
    the off-diagonal blocks are the cross terms.

    :param source: 2-D tensor, one sample per row (n_s rows).
    :param target: 2-D tensor with the same number of columns (n_t rows).
    :param kernel_mul: neighbouring bandwidths differ by ``kernel_mul ** 2``.
    :param kernel_num: number of bandwidths (kernels) summed.
    :param fix_sigma: base bandwidth; when falsy (None or 0) the mean squared
        distance over all off-diagonal pairs is used instead.
    :return: ``(n_s + n_t, n_s + n_t)`` kernel matrix (not normalised by
        ``kernel_num``).
    """
    # The simplest version: K = \beta * k where \beta = 1
    total = torch.cat([source, target], dim=0)
    n_samples = total.size(0)

    # ||x_i - x_j||^2 = ||x_i||^2 - 2 <x_i, x_j> + ||x_j||^2. Clamp the tiny
    # negative values that floating-point cancellation produces (e.g. on the
    # diagonal) so that no kernel value exceeds 1.
    sq_norms = torch.sum(total * total, dim=1, keepdim=True)
    L2_distance = (sq_norms - 2.0 * torch.matmul(total, total.t()) + sq_norms.t()).clamp_min(0.0)

    if fix_sigma:
        bandwidth = fix_sigma
    else:
        # Mean squared distance over the n*(n-1) off-diagonal pairs of the
        # combined set. It is detached so the bandwidth is a constant for
        # backprop.
        bandwidth = torch.sum(L2_distance.detach()) / (n_samples ** 2 - n_samples)

    # The grid below has ratio kernel_mul**2, so centre it on ``bandwidth``
    # by dividing by kernel_mul**(2 * (kernel_num // 2)). With the defaults
    # this gives the bandwidths [2^-4, 2^-2, 1, 2^2, 2^4] * sigma.
    bandwidth = bandwidth / (kernel_mul ** (2 * (kernel_num // 2)))
    bandwidth_list = [bandwidth * (kernel_mul ** (2 * i)) for i in range(kernel_num)]
    kernel_val = [torch.exp(-L2_distance / bandwidth_temp) for bandwidth_temp in bandwidth_list]

    return sum(kernel_val)  # final multi-bandwidth Gaussian kernel matrix

def MK_MMD(source, target, kernel_mul=2.0, kernel_num=5, fix_sigma=None):
    """Multi-kernel MMD estimate between a source and a target sample.

    Computes::

        mean_{i<j} k(s_i, s_j) + mean_{i<j} k(t_i, t_j) - 2 * mean_{i,j} k(s_i, t_j)

    where ``k`` is the multi-bandwidth kernel from ``gaussian_kernel``. Each
    within-domain term is normalised by that domain's own number of pairs, so
    ``source`` and ``target`` may have different numbers of rows. For equal
    sizes the value matches the earlier equal-batch formulation.

    :param source: 2-D tensor, one sample per row (n_s rows).
    :param target: 2-D tensor, one sample per row (n_t rows).
    :param kernel_mul, kernel_num, fix_sigma: forwarded to ``gaussian_kernel``.
    :return: scalar tensor.
    :raises ValueError: if either sample has fewer than two rows. No
        within-domain pair exists then, and the estimate would be NaN.
    """
    n_s = source.size(0)
    n_t = target.size(0)
    if n_s < 2 or n_t < 2:
        raise ValueError(
            f"MK_MMD needs at least 2 samples per domain, got {n_s} source and {n_t} target"
        )

    kernels = gaussian_kernel(source, target,
        kernel_mul=kernel_mul, kernel_num=kernel_num, fix_sigma=fix_sigma)

    # Strict upper triangles give each unordered within-domain pair once.
    ss = torch.triu(kernels[:n_s, :n_s], diagonal=1).sum() / (n_s * (n_s - 1) / 2)
    tt = torch.triu(kernels[n_s:, n_s:], diagonal=1).sum() / (n_t * (n_t - 1) / 2)
    # Both cross blocks are summed; together they give 2 * mean k(s_i, t_j).
    st = (kernels[:n_s, n_s:].sum() + kernels[n_s:, :n_s].sum()) / (n_s * n_t)

    return ss + tt - st

# 1. Model Part

In [None]:
import torch
from torch.autograd import Variable
import torchvision.models as models
import torch.nn as nn
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset,DataLoader
import numpy as np
import torch.optim as optim
import pandas as pd
from torchsummary import summary
import os

class EEGNet_ReLU(torch.nn.Module):
    """EEGNet-style classifier with ReLU activations (no domain adaptation).

    Pipeline: temporal convolution -> grouped convolution over the height-2
    axis -> separable temporal convolution -> flatten + linear read-out.

    The read-out width of 736 matches inputs of shape (batch, 1, 2, 750).
    That input gives 32 feature maps x 23 time steps after the /4 and /8
    average pools.

    :param n_output: number of output classes (logits).
    """

    def __init__(self, n_output):
        super().__init__()
        # Layers are created in the same order as always, so a fixed torch
        # seed still produces the same initial weights.
        self.firstConv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(1, 51), padding=(0, 25), bias=False),
            nn.BatchNorm2d(16),
        )
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=(2, 1), groups=8, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)),
            nn.Dropout(p=0.35),
        )
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=(1, 15), padding=(0, 7), bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8)),
            nn.Dropout(p=0.35),
        )
        self.classify = nn.Sequential(
            nn.Flatten(),
            nn.Linear(736, n_output),
        )

    def forward(self, x):
        """Return class logits of shape (batch, n_output)."""
        for stage in (self.firstConv, self.depthwiseConv, self.separableConv, self.classify):
            x = stage(x)
        return x

class EEGNet_DDC(nn.Module):
    """EEGNet backbone with a Deep Domain Confusion (DDC) head.

    Follows Tzeng et al., "Deep Domain Confusion: Maximizing for Domain
    Invariance" (https://arxiv.org/abs/1412.3474). During training, the
    bottleneck activations of the source and target batches are compared
    with the linear ``MMD`` loss.

    :param num_classes: int, number of output classes.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.firstConv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(1, 51), padding=(0, 25), bias=False),
            nn.BatchNorm2d(16),
        )
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=(2, 1), groups=8, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 4), stride=(1, 4)),
            nn.Dropout(p=0.35),
        )
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=(1, 15), padding=(0, 7), bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1, 8), stride=(1, 8)),
            nn.Dropout(p=0.35),
        )

        # Bottleneck as in the DDC paper: "a lower dimensional layer can be
        # used to regularize the training of the source classifier and
        # prevent overfitting to the particular nuances of the source
        # distribution".
        self.bottleneck = nn.Sequential(
            nn.Flatten(),
            nn.Linear(736, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
        )
        self.classifier = nn.Sequential(nn.Linear(256, num_classes))

    def _extract(self, x):
        """Run the shared feature extractor and return the bottleneck output."""
        for stage in (self.firstConv, self.depthwiseConv, self.separableConv, self.bottleneck):
            x = stage(x)
        return x

    def forward(self, source, target):
        """Return ``(source_logits, mmd_loss)``.

        ``target`` is only processed in training mode. In eval mode it is
        ignored and ``mmd_loss`` is the integer 0.
        """
        source_features = self._extract(source)

        mmd_loss = 0
        if self.training:
            mmd_loss += MMD(source_features, self._extract(target))

        return self.classifier(source_features), mmd_loss


class EEGNet_DAN(nn.Module):
    """EEGNet backbone trained with Deep Adaptation Network (DAN) style losses.

    DAN (Long et al., "Learning Transferable Features with Deep Adaptation
    Networks", ICML 2015, https://arxiv.org/abs/1502.02791) penalises the
    multi-kernel MMD between source and target activations of several
    task-specific layers. Here the penalty is computed on the outputs of
    ``bottleneck1``, ``bottleneck2`` and the classifier logits, and the three
    terms are combined with ``mmd_weights``.

    :param num_classes: int, number of output classes.
    :param mmd_weights: three weights for the MK-MMD terms of (bottleneck1,
        bottleneck2, logits). The default reproduces the fixed weights
        ``1/3, 1/6, 1/6``.
    :raises ValueError: if ``mmd_weights`` does not contain exactly 3 values.
    """

    def __init__(self, num_classes, mmd_weights=(1 / 3, 1 / 6, 1 / 6)):
        super(EEGNet_DAN, self).__init__()
        mmd_weights = tuple(float(w) for w in mmd_weights)
        if len(mmd_weights) != 3:
            raise ValueError(
                f"mmd_weights must contain exactly 3 values, got {len(mmd_weights)}"
            )
        self.mmd_weights = mmd_weights

        self.firstConv = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=(1,51), stride=(1,1), padding=(0,25),bias=False),
            nn.BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        self.depthwiseConv = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=(2,1), stride=(1,1), groups=8,bias=False),
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1,4), stride=(1,4),padding=0),
            nn.Dropout(p=0.35)
        )
        self.separableConv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=(1,15), stride=(1,1), padding=(0,7),bias=False),
            nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=(1,8), stride=(1,8),padding=0),
            nn.Dropout(p=0.35)
        )

        # Bottleneck layers (the idea comes from the DDC paper): "a lower
        # dimensional layer can be used to regularize the training of the
        # source classifier and prevent overfitting to the particular nuances
        # of the source distribution".
        self.bottleneck1 = nn.Sequential(
            nn.Flatten(),
            nn.Linear(736, 256),
            nn.ReLU(inplace=True)
        )
        self.bottleneck2 = nn.Sequential(
            nn.Linear(256, 256),
            nn.ReLU(inplace=True)
        )

        self.classifier = nn.Sequential(
            nn.Linear(256, num_classes, bias=True)
        )

    def _features(self, x):
        """Return the ``(bottleneck1, bottleneck2)`` activations for input ``x``."""
        x = self.firstConv(x)
        x = self.depthwiseConv(x)
        x = self.separableConv(x)
        f1 = self.bottleneck1(x)
        return f1, self.bottleneck2(f1)

    def forward(self, source, target):
        """Classify ``source`` and compute the weighted MK-MMD to ``target``.

        :return: ``(source_logits, mmd_loss)``, where ``mmd_loss`` is the
            ``mmd_weights``-weighted sum of the MK-MMD terms on bottleneck1,
            bottleneck2 and the logits.

        NOTE(review): there is no ``self.training`` guard. The target branch
        and all three MK-MMD terms are computed in eval mode too, even when
        the caller only needs the logits.
        """
        f1_s, f2_s = self._features(source)
        f1_t, f2_t = self._features(target)
        mmd_loss1 = MK_MMD(f1_s, f1_t)
        mmd_loss2 = MK_MMD(f2_s, f2_t)

        result_s = self.classifier(f2_s)
        result_t = self.classifier(f2_t)
        mmd_loss3 = MK_MMD(result_s, result_t)

        w1, w2, w3 = self.mmd_weights
        mmd_loss = w1 * mmd_loss1 + w2 * mmd_loss2 + w3 * mmd_loss3

        return result_s, mmd_loss


# 2. Loss Part (the MMD and MK-MMD losses are defined in the first cell, "Loss Part")

# 3. Evaluation Part

## 3.1 Data Import

In [None]:
import numpy as np


def read_bci_data(data_dir='drive/MyDrive/EEGNet'):
    """Load the BCI recordings of subject S4b (source) and X11b (target).

    Original note: two subjects, S4b and X11b. The experiment consists of 3
    sessions for each subject, and each session consists of 4 to 9 runs.

    For each subject the ``*_train.npz`` and ``*_test.npz`` files are
    concatenated. All of S4b therefore forms the source domain and all of
    X11b forms the target domain.

    :param data_dir: directory containing S4b_train.npz, S4b_test.npz,
        X11b_train.npz and X11b_test.npz. Each file must hold 'signal' and
        'label' arrays.
    :return: ``(source_data, source_label, target_data, target_label)``.
        Signals of shape (N, A, B) are returned as (N, 1, B, A), for example
        (1080, 1, 2, 750) for the shipped files. Labels are shifted down by 1
        (presumably 1-based class ids -> 0-based; confirm against the data).
        NaNs in a signal array are replaced by that array's NaN-ignoring mean.
    """
    def _load_subject(subject):
        # Concatenate one subject's train/test splits. The ``with`` blocks
        # close the NpzFile handles that np.load opens.
        with np.load(f"{data_dir}/{subject}_train.npz") as train_split:
            with np.load(f"{data_dir}/{subject}_test.npz") as test_split:
                signal = np.concatenate((train_split['signal'], test_split['signal']), axis=0)
                label = np.concatenate((train_split['label'], test_split['label']), axis=0)

        # (N, A, B) -> (N, 1, B, A): add a singleton channel axis, swap the last two.
        signal = np.transpose(np.expand_dims(signal, axis=1), (0, 1, 3, 2))
        # Fill missing values with the mean of this subject's non-NaN values.
        signal[np.isnan(signal)] = np.nanmean(signal)
        return signal, label - 1

    source_data, source_label = _load_subject('S4b')
    target_data, target_label = _load_subject('X11b')
    return source_data, source_label, target_data, target_label


source_data, source_label, target_data, target_label = read_bci_data()
# Print the four array shapes, one per line, as a sanity check of the load.
print(*(arr.shape for arr in (source_data, source_label, target_data, target_label)), sep="\n")


(1080, 1, 2, 750)
(1080,)
(1080, 1, 2, 750)
(1080,)


## 3.2 Without DA (source-only baseline)

In [None]:

def testing(x_test,y_test,model,device):

    # model.load_state_dict(torch.load(filepath))
    model.eval()
    with torch.no_grad():
        model.cuda(0)
        n = x_test.shape[0]

        x_test = x_test.astype("float32")
        y_test = y_test.astype("float32").reshape(y_test.shape[0],)

        x_test, y_test = Variable(torch.from_numpy(x_test)),Variable(torch.from_numpy(y_test))

        x_test,y_test = x_test.to(device),y_test.to(device)
        y_pred_test = model(x_test)

        correct_test = (torch.max(y_pred_test,1)[1]==y_test).sum().item()
        test_accuracy = correct_test/n
        # print("testing accuracy:",correct/n)

    return test_accuracy


def train(source_data, source_label, target_data, target_label, epochs=500, lr=1e-3,
          device="cuda:0"):
    """Source-only baseline: train EEGNet_ReLU on the source domain.

    Every epoch performs one full-batch RMSprop step on all source samples,
    then measures accuracy on the target domain with ``testing``.

    :param source_data: numpy array of source inputs, shaped as the model
        summary expects, i.e. (N, 1, 2, 750).
    :param source_label: numpy array of 0-based source labels, (N,) or (N, 1).
    :param target_data: numpy array of target inputs (evaluation only).
    :param target_label: numpy array of 0-based target labels.
    :param epochs: number of full-batch updates.
    :param lr: RMSprop learning rate.
    :param device: torch device string (previously hard-coded to "cuda:0").
    :return: ``(max_training_accuracy, max_test_accuracy)`` over all epochs.
    """
    torch.manual_seed(1)    # reproducible initialisation
    device = torch.device(device)

    # The source set is moved to the device once rather than on every epoch.
    x = torch.from_numpy(source_data.astype("float32")).to(device)
    # ``.long()`` replaces ``torch.tensor(tensor, dtype=torch.long)``, which
    # emitted a copy-construct UserWarning.
    y = torch.from_numpy(
        source_label.astype("float32").reshape(source_label.shape[0])
    ).long().to(device)

    model = EEGNet_ReLU(n_output=2)
    print(model)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=lr, momentum=0.2)
    # torchsummary runs a dummy forward pass on device.type ("cuda" / "cpu").
    summary(model.to(device), (1, 2, 750), device=device.type)

    max_training_accuracy = 0
    max_test_accuracy = 0

    for epoch in range(epochs):
        model.train()  # ``testing`` leaves the model in eval mode
        y_pred = model(x)
        loss = criterion(y_pred, y)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Accuracy of this epoch's pre-update predictions (dropout active).
        train_accuracy = (torch.max(y_pred, 1)[1] == y).sum().item() / y.shape[0]
        test_accuracy = testing(target_data, target_label, model, device)

        print("epochs:",epoch,"loss:",loss.item(),"D_s Accuracy:",train_accuracy,"D_t Accuracy:",test_accuracy)

        max_training_accuracy = max(train_accuracy, max_training_accuracy)
        max_test_accuracy = max(test_accuracy, max_test_accuracy)

    print("The maximum train accuracy is: ", max_training_accuracy, "the maximum Test Accuracy is: ", max_test_accuracy)
    return max_training_accuracy, max_test_accuracy

# Baseline without domain adaptation: load both domains, then train on the
# source domain and report accuracy on the target domain.
source_data, source_label, target_data, target_label = read_bci_data()
train(source_data, source_label, target_data, target_label, lr=1e-3, epochs=500)

  y = torch.tensor(y, dtype=torch.long)


EEGNet_ReLU(
  (firstConv): Sequential(
    (0): Conv2d(1, 16, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (depthwiseConv): Sequential(
    (0): Conv2d(16, 32, kernel_size=(2, 1), stride=(1, 1), groups=8, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0)
    (4): Dropout(p=0.35, inplace=False)
  )
  (separableConv): Sequential(
    (0): Conv2d(32, 32, kernel_size=(1, 15), stride=(1, 1), padding=(0, 7), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0)
    (4): Dropout(p=0.35, inplace=False)
  )
  (classify): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=736, out_features=2, bia

## 3.3 With DA (DDC)

In [None]:
def test(model, data, label, device):
    model.eval()

    data, label = data.to(device), label.to(device)
    pred, _ = model(data, data)

    correct_cnt = (torch.max(pred, 1)[1]== label).sum().item()
    sample_cnt = data.shape[0]
    accuracy = correct_cnt / sample_cnt

    return accuracy


# Domain adaptation with DDC: S4b is the labelled source domain and X11b the
# target domain (its labels are only used for evaluation).
torch.manual_seed(1)    # reproducible
epochs = 500
lr = 0.5  # SGD learning rate; this is the value the optimizer actually uses
lambda_factor = 0.5  # weight of the MMD term in the total loss

max_training_accuracy = 0
max_test_accuracy = 0
device = torch.device("cuda:0")

source_data, source_label, target_data, target_label = read_bci_data()

# numpy -> torch, moved to the device once. ``.long()`` avoids the UserWarning
# raised by ``torch.tensor(tensor, dtype=torch.long)``.
source_data = torch.from_numpy(source_data.astype("float32")).to(device)
source_label = torch.from_numpy(
    source_label.astype("float32").reshape(source_label.shape[0])
).long().to(device)
target_data = torch.from_numpy(target_data.astype("float32")).to(device)
target_label = torch.from_numpy(
    target_label.astype("float32").reshape(target_label.shape[0])
).long().to(device)

model = EEGNet_DDC(num_classes=2)
print(model)
summary(model.cuda(), [(1, 2, 750), (1, 2, 750)])

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.5, weight_decay=5e-4)
model.to(device)

loss_history = []
train_accuracy_history = []
test_accuracy_history = []

for epoch in range(epochs):
    model.train()  # ``test`` leaves the model in eval mode

    preds, mmd_loss = model(source_data, target_data)

    # Classification loss on the labelled source samples only.
    clf_loss = torch.nn.functional.cross_entropy(preds, source_label)

    # Total loss: classification + lambda * MMD (equation 2 in the paper).
    loss = clf_loss + lambda_factor * mmd_loss
    loss_history.append(loss.item())

    correct_cnt = (torch.max(preds, 1)[1] == source_label).sum().item()
    accuracy = correct_cnt / source_data.shape[0]
    train_accuracy_history.append(accuracy)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    test_accuracy = test(model, target_data, target_label, device)
    test_accuracy_history.append(test_accuracy)
    max_training_accuracy = max(max_training_accuracy, accuracy)
    max_test_accuracy = max(max_test_accuracy, test_accuracy)

    print("epochs:", epoch, "classification loss:", clf_loss.item(), "MMD loss", mmd_loss.item(), "total loss:", loss.item(),
            "D_t accuracy:", test_accuracy, "D_s accuracy:", accuracy
            , "all:", test_accuracy + accuracy)

print("The max target accuracy is: ", max_test_accuracy)


  source_label = torch.tensor(source_label, dtype=torch.long)
  target_label = torch.tensor(target_label, dtype=torch.long)


EEGNet_DDC(
  (firstConv): Sequential(
    (0): Conv2d(1, 16, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (depthwiseConv): Sequential(
    (0): Conv2d(16, 32, kernel_size=(2, 1), stride=(1, 1), groups=8, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0)
    (4): Dropout(p=0.35, inplace=False)
  )
  (separableConv): Sequential(
    (0): Conv2d(32, 32, kernel_size=(1, 15), stride=(1, 1), padding=(0, 7), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0)
    (4): Dropout(p=0.35, inplace=False)
  )
  (bottleneck): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=736, out_features=256, 

## 3.4 With DA (DAN)

In [None]:
def test(model, data, label, device):
    model.eval()

    data, label = data.to(device), label.to(device)
    pred, _ = model(data, data)

    correct_cnt = (torch.max(pred, 1)[1]== label).sum().item()
    sample_cnt = data.shape[0]
    accuracy = correct_cnt / sample_cnt

    return accuracy


# Domain adaptation with DAN (multi-kernel MMD on several layers): S4b is the
# labelled source domain and X11b the target domain (labels for evaluation only).
torch.manual_seed(1)    # reproducible
epochs = 500
lr = 0.5  # SGD learning rate; this is the value the optimizer actually uses

max_training_accuracy = 0
max_target_accuracy = 0
device = torch.device("cuda:0")

source_data, source_label, target_data, target_label = read_bci_data()

# numpy -> torch, moved to the device once. ``.long()`` avoids the UserWarning
# raised by ``torch.tensor(tensor, dtype=torch.long)``.
source_data = torch.from_numpy(source_data.astype("float32")).to(device)
source_label = torch.from_numpy(
    source_label.astype("float32").reshape(source_label.shape[0])
).long().to(device)
target_data = torch.from_numpy(target_data.astype("float32")).to(device)
target_label = torch.from_numpy(
    target_label.astype("float32").reshape(target_label.shape[0])
).long().to(device)

model = EEGNet_DAN(num_classes=2)
print(model)
summary(model.cuda(), [(1, 2, 750), (1, 2, 750)])

optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.5, weight_decay=5e-4)
model.to(device)

loss_history = []
train_accuracy_history = []
test_accuracy_history = []

for epoch in range(epochs):
    model.train()  # ``test`` leaves the model in eval mode

    preds, mmd_loss = model(source_data, target_data)

    # Classification loss on the labelled source samples only.
    clf_loss = torch.nn.functional.cross_entropy(preds, source_label)

    # The three MK-MMD terms are already weighted inside EEGNet_DAN.forward.
    loss = clf_loss + mmd_loss
    loss_history.append(loss.item())

    correct_cnt = (torch.max(preds, 1)[1] == source_label).sum().item()
    accuracy = correct_cnt / source_data.shape[0]
    train_accuracy_history.append(accuracy)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    test_accuracy = test(model, target_data, target_label, device)
    test_accuracy_history.append(test_accuracy)
    max_training_accuracy = max(max_training_accuracy, accuracy)
    max_target_accuracy = max(max_target_accuracy, test_accuracy)

    print("epochs:", epoch, "classification loss:", clf_loss.item(), "MMD loss", mmd_loss.item(), "total loss:", loss.item(),
            "D_t accuracy:", test_accuracy, "D_s accuracy:", accuracy
            , "all:", test_accuracy + accuracy)

print("The max target accuracy is: ", max_target_accuracy)

  source_label = torch.tensor(source_label, dtype=torch.long)
  target_label = torch.tensor(target_label, dtype=torch.long)


EEGNet_DAN(
  (firstConv): Sequential(
    (0): Conv2d(1, 16, kernel_size=(1, 51), stride=(1, 1), padding=(0, 25), bias=False)
    (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (depthwiseConv): Sequential(
    (0): Conv2d(16, 32, kernel_size=(2, 1), stride=(1, 1), groups=8, bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): AvgPool2d(kernel_size=(1, 4), stride=(1, 4), padding=0)
    (4): Dropout(p=0.35, inplace=False)
  )
  (separableConv): Sequential(
    (0): Conv2d(32, 32, kernel_size=(1, 15), stride=(1, 1), padding=(0, 7), bias=False)
    (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): AvgPool2d(kernel_size=(1, 8), stride=(1, 8), padding=0)
    (4): Dropout(p=0.35, inplace=False)
  )
  (bottleneck1): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=736, out_features=256,