In [51]:
import torch
import torch.nn as nn
from torch.optim import Optimizer
from torch.utils.data import DataLoader


device = torch.device("cuda:0" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
# Set the random seed
torch.manual_seed(0)

'''
assume that datasets are given, after preprocessing
'''
# Set the random seed
torch.manual_seed(0)

# Create a random dataset
dataset_size = 1200
data = torch.randn(dataset_size, 22, 128, 8)

labels = torch.randint(2, (dataset_size,))
train_dataset = list(zip(data[:500], torch.zeros(500, dtype=torch.int)))
train_dataset += list(zip(data[500:1000], torch.ones(500, dtype=torch.int)))

test_dataset = list(zip(data[1000:1100], torch.zeros(200, dtype=torch.int)))
test_dataset += list(zip(data[1100:], torch.ones(200, dtype=torch.int)))

In [58]:


class SciCNN(nn.Module):

    def __init__(self):
        super(SciCNN, self).__init__()        

        self.inception1 = Inception(8, 8, 16, 8, 8)
        self.maxpool1 = nn.MaxPool2d((4, 1), stride=(4, 1), ceil_mode=True)
        self.inception2 = Inception(16, 16, 8, 16, 4)
        self.maxpool2 = nn.MaxPool2d((4, 1), stride=(4, 1), ceil_mode=True)
        self.inception3 = Inception(32, 32, 4, 32, 2)
        self.maxpool3 = nn.MaxPool2d((8, 1), stride=(8, 1), ceil_mode=True)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.4)
        self.fc = nn.Linear(22*64, 16)
        self.npc = NPC()
        
        for m in self.modules():
            if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        x = self.inception1(x)
        x = self.maxpool1(x)
        x = self.inception2(x)
        x = self.maxpool2(x)
        x = self.inception3(x)
        x = self.maxpool3(x)
        x = self.flatten(x)
        x = self.dropout(x)
        x = self.fc(x)
        return x

class Inception(nn.Module):

    def __init__(self, in_channels, ch1, ch1_kernel, ch2, ch2_kernel):
        super(Inception, self).__init__()
        self.branch1 = BasicConv1d(in_channels, ch1, kernel_size=(ch1_kernel, 1))
        self.branch2 = BasicConv1d(in_channels, ch2, kernel_size=(ch2_kernel, 1))

    def forward(self, x):
        branch1 = self.branch1(x)
        branch2 = self.branch2(x)
        return torch.cat([branch1, branch2], dim=3)

class BasicConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(BasicConv1d, self).__init__()
        self.conv = nn.Sequential(
                            nn.Conv2d(in_channels, out_channels, **kwargs, bias=False),
                            nn.BatchNorm2d(out_channels),
                            nn.ReLU()
                            )
    def forward(self, x):
        return self.conv(x)
    
class NPC(nn.Module):
    def __init__(self, num_clusters=256):
        super(NPC, self).__init__()
        # 256 predefined positions of NPC clusters
        self.position = nn.Parameter(torch.randn(num_clusters, 16, 1))
        self.label = nn.Parameter(torch.randint(1, (num_clusters,)), requires_grad=False)

model = SciCNN().to(device)

def npc_loss(output, label, model):
    # output: (batch    _size, 16, 1)
    distances = torch.norm(output - model.npc.position, dim=1)
    closest_position_index = torch.argmin(distances)
    closest_position = model.npc.position[closest_position_index]
    model.npc.label[closest_position_index] = torch.mean(label)
    loss = torch.norm(output - closest_position)
    return loss
    

model = SciCNN().to(device)
loss_function = npc_loss
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, weight_decay=5e-4)

model.train()
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=50, shuffle=False)


import time
start = time.time()
for epoch in range(100) :
    print("{}th epoch starting.".format(epoch))
    for i, (images, labels) in enumerate(train_loader) :
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        train_loss = loss_function(model(images), labels, model)
        train_loss.backward()

        optimizer.step()

    print ("Epoch [{}] Loss: {:.4f}".format(epoch+1, train_loss.item()))

end = time.time()
print("Time ellapsed in training is: {}".format(end - start))


model.eval()
test_loss, correct, total = 0, 0, 0

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)
with torch.no_grad():  #using context manager
    for images, labels in test_loader :
        images, labels = images.to(device), labels.to(device)

        output = model(images)
        test_loss += loss_function(output, labels).item()

        pred = output.max(1, keepdim=True)[1]
        correct += pred.eq(labels.view_as(pred)).sum().item()

        total += labels.size(0)

print('[Test set] Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss /total, correct, total,
        100. * correct / total))

0th epoch starting.


RuntimeError: Given groups=1, weight of size [8, 8, 16, 1], expected input[50, 22, 128, 8] to have 8 channels, but got 22 channels instead

In [25]:
import torch
import torch.nn as nn
x = torch.randn(3, 16, 1)
torch.sum((x - nn.Parameter(torch.randn(3, 16, 1))) ** 2, dim=1)

tensor([[21.0856],
        [19.9419],
        [39.3159]], grad_fn=<SumBackward1>)