In [1]:
%load_ext autoreload
%autoreload 2
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
from Encoder.net_generator import Classifier
from torchinfo import summary
from torchvision import models

In [2]:
# Default Options
length = 512
channel = 96
min_CNN = 200
n_classes = 40
classes = range(n_classes)
classifier= ""
GPUindex=0
kind="from-scratch"

In [3]:
# Options for EEGChannelNet
length=440
channel=128
classifier="EEGChannelNet"

In [16]:
class SiameseNet(nn.Module):
    def __init__(self):
        super(SiameseNet, self).__init__()
        # Define sub-networks
        self.image_net = models.inception_v3(pretrained=True)
        self.image_net.fc = nn.Linear(self.image_net.fc.in_features, 256)
        self.eeg_net, _ = Classifier(
                    n_classes,
                    classes,
                    classifier,
                    GPUindex,
                    length,
                    channel,
                    min_CNN,
                    kind)

    def forward(self, img, eeg):
        img_out = self.image_net(img)
        eeg_out = self.eeg_net(eeg)
        return img_out, eeg_out

In [17]:
# Instantiate the network and loss function
sample_eeg = torch.randn(1, 128, 440)
sample_img = torch.randn(1,3,299,299)
siamese_net = SiameseNet()
summary(siamese_net, input_data=[sample_img, sample_eeg])

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /home/titan/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth


  0%|          | 0.00/104M [00:00<?, ?B/s]

DONE: CREATE TORCH CLASSIFIER
classifier_EEGChannelNet(
  (encoder): FeaturesExtractor(
    (temporal_block): TemporalBlock(
      (layers): ModuleList(
        (0): ConvLayer2D(
          (norm): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv): Conv2d(1, 10, kernel_size=(1, 33), stride=(1, 2), padding=(0, 16))
          (drop): Dropout2d(p=0.2, inplace=False)
        )
        (1): ConvLayer2D(
          (norm): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv): Conv2d(1, 10, kernel_size=(1, 33), stride=(1, 2), padding=(0, 32), dilation=(1, 2))
          (drop): Dropout2d(p=0.2, inplace=False)
        )
        (2): ConvLayer2D(
          (norm): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inplace=True)
          (conv): Conv2d(1, 10, kernel_size=(1, 33), stride=

Layer (type:depth-idx)                             Output Shape              Param #
SiameseNet                                         [1, 256]                  --
├─Inception3: 1-1                                  [1, 256]                  3,326,696
│    └─BasicConv2d: 2-1                            [1, 32, 149, 149]         --
│    │    └─Conv2d: 3-1                            [1, 32, 149, 149]         864
│    │    └─BatchNorm2d: 3-2                       [1, 32, 149, 149]         64
│    └─BasicConv2d: 2-2                            [1, 32, 147, 147]         --
│    │    └─Conv2d: 3-3                            [1, 32, 147, 147]         9,216
│    │    └─BatchNorm2d: 3-4                       [1, 32, 147, 147]         64
│    └─BasicConv2d: 2-3                            [1, 64, 147, 147]         --
│    │    └─Conv2d: 3-5                            [1, 64, 147, 147]         18,432
│    │    └─BatchNorm2d: 3-6                       [1, 64, 147, 147]         128
│    └─MaxPool2d: 2