In [None]:
import torch
import numpy as np
import torch.nn as nn
from torch import optim
from torch.utils.data import Dataset, DataLoader

In [None]:
# Load dev-set features and labels. allow_pickle=True is presumably needed because
# each file holds an object array of per-utterance arrays (TODO confirm); it unpickles
# the file contents, so only point it at trusted data.
evalx, evaly = (
    np.load(path, allow_pickle=True)
    for path in ("../data/dev.npy", "../data/dev_labels.npy")
)

In [None]:
class simpleDataset(Dataset):
    def __init__(self, x, y, context_size=12):
        super().__init__()
        self.x = np.concatenate(x)
        self.y = np.concatenate(y)
        self.context_size = context_size
        
        # padding
        before = np.zeros((context_size, self.x.shape[1]))
        after = np.zeros((context_size, self.x.shape[1]))
        self.x = np.concatenate([before, self.x, after], axis=0)
    
    def __len__(self):
        return len(self.x) - 2 * self.context_size

    def __getitem__(self, index):
        x_item = self.x[index : (index+2*self.context_size+1)].reshape(-1)
        y_item = self.y[index]
        return torch.from_numpy(x_item).float(), y_item

class MLP(nn.Module):
    def __init__(self, context_size):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1000, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, 138)
        )
    
    def forward(self, input):
        return self.net(input)

In [None]:
torch.manual_seed(0)

eval_dataset = simpleDataset(evalx[:5], evaly[:5], context_size=12)
eval_dataloader = DataLoader(eval_dataset, batch_size=256, shuffle=True, num_workers=0)

model = MLP(context_size=12)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
curr_epoch = 0

In [None]:
# for epoch
for i in range(100):
    running_loss = 0.0
    total, correct = 0, 0
    for batch_idx, (data, target) in enumerate(eval_dataloader):
        optimizer.zero_grad()
        
        outputs = model(data)
        loss = criterion(outputs, target)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        _, predicted = torch.max(outputs.data, 1)
        predicted.detach_()
        total += target.size(0)
        correct += (predicted == target).sum().item()
    curr_epoch += 1
    print("{}th epoch ".format(curr_epoch) +
          "loss: {:.4} ".format(running_loss / len(eval_dataloader)) + 
          "acc: {:.4}%".format(correct / total * 100))

In [None]:
with torch.no_grad():
    total, correct = 0, 0
    for batch_idx, (data, target) in enumerate(eval_dataloader):
        outputs = model(data)
        _, predicted = torch.max(outputs.data, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

print("accuracy: {}%".format(correct/total * 100))