# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# Define the GCN model
class GCN(nn.Module):
    """Two-layer graph network: linear -> adjacency propagation -> ReLU -> linear.

    Only the first layer's output is multiplied by ``adj``; the second linear
    layer is applied row-wise without a further propagation step.

    Args:
        input_dim: size of each node's input feature vector.
        hidden_dim: size of the hidden node representation.
        output_dim: size of each node's output vector.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.gcn1 = nn.Linear(input_dim, hidden_dim)
        self.gcn2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()

    def forward(self, x, adj):
        """Return ``gcn2(relu(adj @ gcn1(x)))``.

        The propagation uses a plain matrix product, so ``adj`` and the
        projected features must both be 2-D (``adj`` (N, N), ``x`` (N, input_dim)).
        """
        projected = self.gcn1(x)
        propagated = adj.mm(projected)
        return self.gcn2(self.relu(propagated))

# Define the CNN model
class CNN(nn.Module):
    """1-D convolutional encoder mapping a (C, L) signal to a 64-dim vector.

    Two conv (kernel 3, padding 1) + max-pool (kernel 2, stride 2) stages
    shrink the length by a factor of 4 before two linear layers.

    Args:
        in_channels: number of input channels (default 4).
        seq_len: input length the first linear layer is sized for. The
            default of 100 reproduces the original ``32 * 25`` flatten size.
    """

    def __init__(self, in_channels=4, seq_len=100):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1)
        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        # Padding keeps the conv length unchanged; each pool floors length / 2.
        self.flat_dim = 32 * ((seq_len // 2) // 2)
        self.fc1 = nn.Linear(self.flat_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        # NOTE(review): no activation is applied between layers (only max-pool
        # is nonlinear); confirm that is intended.

    def forward(self, x):
        """Encode ``x`` of shape (B, C, L), or unbatched (C, L), into (B, 64).

        Raises:
            ValueError: if the pooled feature map does not have the size the
                first linear layer was built for (i.e. wrong input length).
                The old ``view(-1, ...)`` silently folded such inputs into a
                wrong number of rows.
        """
        x = self.pool(self.conv1(x))
        x = self.pool(self.conv2(x))
        if x.dim() == 2:
            # Unbatched (C, L) input: add a batch axis so the result is (1, 64).
            x = x.unsqueeze(0)
        per_sample = x.shape[1:].numel()
        if per_sample != self.flat_dim:
            raise ValueError(
                f"CNN expected {self.flat_dim} features per sample after pooling, "
                f"got {per_sample}; check the input length against seq_len"
            )
        x = x.flatten(1)
        x = self.fc1(x)
        return self.fc2(x)

# Define the combined GCN-CNN model
class GCNCNN(nn.Module):
    """GCN whose node features are augmented with a CNN embedding of ``x``.

    ``x`` is fed to the CNN as a single sample (``x.unsqueeze(0)``), so its
    rows act as CNN channels and its columns as sequence positions. The CNN's
    64-dim output is appended to every row of ``x`` and the result goes through
    the GCN, so ``input_dim`` must equal ``x.size(1) + 64``.

    NOTE(review): with the default ``CNN()`` (4 input channels, flatten size
    ``32 * 25``) this presumably means ``x`` is (4, 100) and ``adj`` is (4, 4);
    confirm the intended shapes against the data pipeline.

    Args:
        input_dim: GCN input size; must be ``x.size(1) + 64``.
        hidden_dim: GCN hidden size.
        output_dim: GCN output size per row of ``x``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.gcn = GCN(input_dim, hidden_dim, output_dim)
        self.cnn = CNN()

    def forward(self, x, adj):
        """Return the GCN output for ``x`` concatenated with its CNN embedding."""
        cnn_out = self.cnn(x.unsqueeze(0))  # one embedding row for the whole of x
        # Concatenating a squeezed 1-D embedding onto 2-D ``x`` along dim=1 can
        # never succeed (rank mismatch); broadcast the row to every row of x.
        row_context = cnn_out.expand(x.size(0), -1)
        x = torch.cat((x, row_context), dim=1)
        return self.gcn(x, adj)

# Define the dataset
class VirusDataset(torch.utils.data.Dataset):
    """Map-style dataset over parallel feature, adjacency and label sequences.

    Item ``i`` is the tuple ``(features[i], adj_matrix[i], labels[i])``.
    The length is taken from ``features`` alone.
    """

    def __init__(self, features, adj_matrix, labels):
        self.features, self.adj_matrix, self.labels = features, adj_matrix, labels

    def __len__(self):
        return len(self.features)

    def __getitem__(self, index):
        sample = self.features[index]
        graph = self.adj_matrix[index]
        target = self.labels[index]
        return sample, graph, target

# Define the training function
def train(model, device, train_loader, optimizer, criterion):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: module called as ``model(features, adj_matrix)``.
        device: target passed to ``.to()`` for every batch tensor.
        train_loader: iterable of ``(features, adj_matrix, labels)`` batches
            that also supports ``len()``.
        optimizer: optimizer stepped once per batch.
        criterion: loss function called as ``criterion(output, labels)``.

    Returns:
        Sum of ``loss.item()`` over batches divided by ``len(train_loader)``.
    """
    model.train()
    epoch_loss = 0.0
    for features, adj_matrix, labels in train_loader:
        features = features.to(device)
        adj_matrix = adj_matrix.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = criterion(model(features, adj_matrix), labels)
        batch_loss.backward()
        optimizer.step()

        epoch_loss += batch_loss.item()
    return epoch_loss / len(train_loader)

# Define the testing function
def test(model, device, test_loader, criterion):
    model.eval()
    correct = 0
    total = 0
    running_loss = 0.0
    with torch.no_grad():
        for data in test:
            features, adj_matrix, labels = data
        features, adj_matrix, labels = features.to(device), adj_matrix.to(device), labels.to(device)
        output = model(features, adj_matrix)
        loss = criterion(output, labels)
        running_loss += loss.item()
        _, predicted = torch.max(output.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    return running_loss/len(test_loader), 100 * correct/total