## Toy problem

In [1]:
import torch
import numpy

In [2]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    print('Running on the GPU.')
else:
    device = torch.device('cpu')
    print('Running on the CPU.')

Running on the GPU.


In [3]:
vector_length = 5
models_ids = ['A', 'B', 'C'] 
#models = numpy.asarray([[43 for i in range(vector_length)], [-828 for i in range(vector_length)], [453 for i in range(vector_length)]], dtype = float)
models = numpy.asarray([numpy.random.normal(43, 1.0, vector_length), numpy.random.normal(-828, 1.0, vector_length), numpy.random.normal(453, 1.0, vector_length)], dtype = float)
data = {}
data_map = []
data_length = 0
for idx, model_id in enumerate(models_ids):
    data[model_id] = {}
    data[model_id]['model_id'] = model_id
    data[model_id]['model'] = models[idx]
    #data[model_id]['vectors'] = [numpy.random.normal(data[model_id]['model'][0], 5.0, vector_length) for i in range(10000)]
    data[model_id]['vectors'] = [numpy.asarray([value for value in data[model_id]['model']]) for i in range(10000)]
    length = len(data[model_id]['vectors'])
    data_map.append((data_length, data_length + length - 1, model_id))
    data_length += length

In [4]:
from torch.utils.data import Dataset, DataLoader, random_split

class Toy_dataset(Dataset):
    def __init__(self, data, data_length, data_map):
        self.data = data
        self.data_length = data_length
        self.data_map = data_map
    def __len__(self):
        return self.data_length
    def __getitem__(self, idx):
        data_tuple = list(filter(lambda data_tuple: data_tuple[0] <= idx and idx <= data_tuple[1], self.data_map))[0]
        data_idx = idx - data_tuple[0]
        data_id = data_tuple[2]
        entry = data[data_id]
        vector = entry['vectors'][data_idx]
        permutations = [model_id for model_id in data]
        models = [data[model_id]['model'] for model_id in permutations]
        permutations = numpy.asarray([model_id == entry['model_id'] for model_id in permutations], dtype = float)
        
        x = numpy.concatenate([vector] + models)
        y = permutations
        
        return x, y
        
toy_dataset = Toy_dataset(data, data_length, data_map)
print(toy_dataset[55])

train_length = int(len(toy_dataset) * 0.7)
test_length = len(toy_dataset) - train_length
cross_length = int(train_length * 0.3)
train_length = train_length - cross_length

train_dataset, cross_dataset, test_dataset = random_split(toy_dataset, [train_length, cross_length, test_length])
train_dataloader = DataLoader(train_dataset, batch_size = 10, shuffle=True, num_workers = 4)
cross_dataloader = DataLoader(cross_dataset, batch_size = len(cross_dataset), num_workers = 4)

(array([  43.3291097 ,   44.25739582,   43.18322558,   43.82884937,
         43.86099993,   43.3291097 ,   44.25739582,   43.18322558,
         43.82884937,   43.86099993, -826.96909435, -828.50882847,
       -827.20874069, -829.50712132, -829.01829793,  451.1246898 ,
        451.36999668,  451.37096499,  451.03752116,  451.31414493]), array([1., 0., 0.]))


In [5]:
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc0 = nn.Linear(4 * vector_length, 4)
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(4, 4)
        self.fc3 = nn.Linear(4, 3)
    def forward(self, x):
        x = F.relu(self.fc0(x))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.sigmoid(self.fc3(x))
        return x

In [6]:
import torch.optim as optim

net = Net().to(device)

optimizer = optim.Adam(net.parameters(), lr = 0.0001)
criterion = nn.BCELoss()

epochs = 10
for epoch in range(epochs):
    for input, target in train_dataloader:
        input = input.to(device, non_blocking=True).float()
        target = target.to(device, non_blocking=True).float()

        net.zero_grad()

        output = net(input)

        loss = criterion(output, target)

        loss.backward()
        optimizer.step()
    with torch.no_grad():
        for input, target in cross_dataloader:
            input = input.to(device, non_blocking=True).float()
            target = target.to(device, non_blocking=True).float()
            output = net(input)
            cross_loss = criterion(output, target)
        
    print('epoch:', epoch, 'loss:', loss, 'cross_loss:', cross_loss)
    
    if cross_loss < 0.01:
        print('Done training.')
        break



epoch: 0 loss: tensor(0.1824, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) cross_loss: tensor(0.2753, device='cuda:0')
epoch: 1 loss: tensor(0.1410, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) cross_loss: tensor(0.2340, device='cuda:0')
epoch: 2 loss: tensor(0.3157, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) cross_loss: tensor(0.2119, device='cuda:0')
epoch: 3 loss: tensor(0.1147, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) cross_loss: tensor(0.1920, device='cuda:0')
epoch: 4 loss: tensor(0.1558, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) cross_loss: tensor(0.1741, device='cuda:0')
epoch: 5 loss: tensor(3.9722e-05, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) cross_loss: tensor(0.1574, device='cuda:0')
epoch: 6 loss: tensor(0.2120, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) cross_loss: tensor(0.1421, device='cuda:0')
epoch: 7 loss: tensor(0.0764, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>) cross_lo

In [7]:
test_dataloader = DataLoader(test_dataset, batch_size = len(test_dataset), num_workers = 4)
with torch.no_grad():
    for input, target in test_dataloader:
        input = input.to(device, non_blocking=True).float()
        target = target.to(device, non_blocking=True).float()
        output = net(input)
        test_loss = criterion(output, target)
    print('test_loss:', test_loss)

test_loss: tensor(0.1038, device='cuda:0')


In [8]:
test_dataloader = DataLoader(test_dataset, batch_size = 1, num_workers = 1)
correct = 0
with torch.no_grad():
    for input, target in test_dataloader:
        input = input.to(device, non_blocking=True).float()
        target = target.to(device, non_blocking=True).float()
        output = net(input)
        if target.max(1)[1] == output.max(1)[1]:
            correct += 1
print(correct / len(test_dataset))

1.0
