In [4]:
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
from model_utility import *
from model_dataloader import *
from torch.utils.data import DataLoader

# Seed torch (model init in a later cell, DataLoader shuffling) and numpy's
# global RNG before anything random happens. Whether GraphList.make_dataset
# draws from numpy's global RNG for the split is not visible here — TODO confirm.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

DATA_DIR = "./data"
dataset = GraphList(f"{DATA_DIR}/AM.npy", f"{DATA_DIR}/train.txt", f"{DATA_DIR}/test.txt")
train_list, valid_list, test_list = dataset.make_dataset()

batch_size = 8
train_dataset = GraphDataset(train_list, batch_size)
valid_dataset = GraphDataset(valid_list, batch_size)
test_dataset  = GraphDataset(test_list, batch_size)

# Only the training loader drops the ragged final batch. Evaluation loaders
# keep it so validation/test metrics cover every sample; Net.forward reads the
# batch size from the input, so a smaller last batch is handled.
train_loader  = DataLoader(train_dataset, batch_size = batch_size, shuffle = True, drop_last = True)
valid_loader  = DataLoader(valid_dataset, batch_size = batch_size, shuffle = False, drop_last = False)
test_loader   = DataLoader(test_dataset, batch_size = batch_size, shuffle = False, drop_last = False)


class Net(nn.Module):
    """Classifier that runs one shared MLP along both trailing axes of its input.

    ``forward`` works in four steps:

    1. Apply ``self.sequential`` (six Linear layers with GELU between them) to
       the last axis of a ``(B, in_features, in_features)`` tensor.
    2. Swap the last two axes.
    3. Apply the *same* ``self.sequential`` again.
    4. Flatten the resulting ``(B, d, d)`` tensor (``d = hidden_dim[5]``) and
       map it to ``num_classes`` scores with ``last_linear``.

    Args:
        hidden_dim: output widths of ``linear1`` ... ``linear6``. Only the
            first six entries are used. Copied into a new list per instance.
        num_classes: output width of ``last_linear``. Default 3.
        in_features: input width of ``linear1``. The MLP is applied to both
            trailing axes, so both must have this size. Default 492.

    ``forward`` returns raw logits, which is what ``nn.CrossEntropyLoss``
    expects. ``predict_proba`` returns their softmax.
    """

    def __init__(self, hidden_dim = (492, 512, 256, 128, 64, 3), num_classes = 3, in_features = 492):
        super(Net, self).__init__()
        # Fresh list per instance. A mutable list default would be one object
        # shared by every Net.
        self.hidden_dim = list(hidden_dim)
        self.gelu    = nn.GELU()
        self.softmax = nn.Softmax(dim = 1)

        # The Linear layers stay registered as attributes (and again inside
        # `sequential`) so the module's state_dict keys are unchanged.
        self.linear1 = nn.Linear(in_features, self.hidden_dim[0])
        self.linear2 = nn.Linear(self.hidden_dim[0], self.hidden_dim[1])
        self.linear3 = nn.Linear(self.hidden_dim[1], self.hidden_dim[2])
        self.linear4 = nn.Linear(self.hidden_dim[2], self.hidden_dim[3])
        self.linear5 = nn.Linear(self.hidden_dim[3], self.hidden_dim[4])
        self.linear6 = nn.Linear(self.hidden_dim[4], self.hidden_dim[5])
        self.sequential = nn.Sequential(
                                self.linear1,
                                self.gelu,
                                self.linear2,
                                self.gelu,
                                self.linear3,
                                self.gelu,
                                self.linear4,
                                self.gelu,
                                self.linear5,
                                self.gelu,
                                self.linear6)

        # Both passes of `sequential` end at width hidden_dim[5], so the
        # flattened feature size is hidden_dim[5] ** 2 (9 for the defaults).
        self.last_linear = nn.Linear(self.hidden_dim[5] ** 2, num_classes)

    def forward(self, graph):
        """Return class logits of shape (B, num_classes).

        ``graph`` must be (B, in_features, in_features): linear1 is applied to
        its last axis and then, after the permute, to its second axis.
        """
        out = self.sequential(graph)     # (B, in_features, d)
        out = out.permute(0, 2, 1)       # (B, d, in_features)
        out = self.sequential(out)       # (B, d, d)
        out = out.flatten(start_dim = 1) # (B, d * d)
        # No softmax here: nn.CrossEntropyLoss applies log-softmax itself.
        # Feeding it probabilities applies softmax twice and flattens the
        # gradients.
        return self.last_linear(out)

    def predict_proba(self, graph):
        """Return per-class probabilities: the softmax (over dim 1) of the logits."""
        return self.softmax(self(graph))

model = Net()

# Smoke test: push one dummy batch through the network and check the output
# shape. no_grad() stops this throwaway forward pass from building an autograd
# graph that would stay alive through `out`.
with torch.no_grad():
    test = torch.ones([batch_size, 492, 492])
    out  = model(test)
assert out.shape == (batch_size, 3), f"unexpected output shape {tuple(out.shape)}"

LEARNING_RATE = 0.001
optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [3]:
NUM_EPOCHS = 100

for epoch in tqdm(range(NUM_EPOCHS)):
    model.train()
    # Both metrics are means over the epoch's batches.
    train_loss     = 0.0
    train_accuracy = 0.0

    for graph, label in train_loader:
        optimizer.zero_grad()

        pred = model(graph)
        # NOTE(review): assumes `label` holds class indices of shape (B,), as
        # CrossEntropyLoss and the argmax comparison below require — TODO
        # confirm against GraphDataset.
        loss = criterion(pred, label)
        loss.backward()
        optimizer.step()

        # .item() yields plain Python floats. Accumulating the `loss` tensor
        # itself would chain every batch's autograd history onto train_loss.
        correct_prediction = torch.argmax(pred, 1) == label
        train_loss        += loss.item() / len(train_loader)
        # Divide by the batch count, like the loss. The undivided per-batch
        # sum is not an accuracy in [0, 1].
        train_accuracy    += correct_prediction.float().mean().item() / len(train_loader)
    print("Epoch {0} Train loss {1:0.5f} Train accuracy {2:0.5f}".format(epoch, train_loss, train_accuracy))

0


KeyboardInterrupt: 