In [45]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from tqdm.notebook import tqdm

In [29]:
def get_dataset(size, seq_len, max_value=255):
    X = torch.zeros((size, seq_len))
    y = torch.zeros((size, seq_len))
    for i in range(size):
        seq = np.random.choice(max_value, size=seq_len, replace=False) / 255
        shuffled_seq = seq.copy()
        np.random.shuffle(shuffled_seq)
        X[i] = torch.from_numpy(shuffled_seq)
        y[i] = torch.from_numpy(seq)
    return X, y

In [37]:
seq_len = 3
size = 1000
get_dataset(size, seq_len)
train = DataLoader(TensorDataset(*get_dataset(size, seq_len)), batch_size=32, shuffle=True)
test = DataLoader(TensorDataset(*get_dataset(size, seq_len)), batch_size=32, shuffle=True)

In [47]:
net = nn.Sequential(nn.Linear(seq_len, seq_len ** 2), nn.ReLU(), nn.Linear(seq_len ** 2, seq_len ** 2), nn.ReLU(),  nn.Linear(seq_len ** 2, seq_len))
opt = optim.Adam(net.parameters(), lr=0.01)
crit = nn.MSELoss()

In [48]:
n_epochs = 100
pbar = tqdm(range(n_epochs))
for epoch in pbar:
    train_loss = []
    test_loss = []
    for x, y in train:
        opt.zero_grad()
        o = net(x)
        loss = crit(y, o)
        loss.backward()
        opt.step()
        train_loss.append(loss)
    with torch.no_grad():
        for x, y in test:
            o = net(x)
            loss = crit(y, o)
        test_loss.append(loss)
    pbar.set_description(f"Epoch {epoch + 1}: Train: {torch.tensor(train_loss).mean().item():.3f}, Test: {torch.tensor(test_loss).mean().item():.3f}")

HBox(children=(HTML(value=''), FloatProgress(value=0.0), HTML(value='')))




In [49]:
x, y = next(iter(test))
preds = net(x)
for i in range(len(x)):
    print(y[i] * 255, preds[i] * 255)

tensor([91., 29., 61.]) tensor([51.9880, 56.5139, 62.5566], grad_fn=<MulBackward0>)
tensor([178., 121., 148.]) tensor([137.7622, 146.3967, 151.4230], grad_fn=<MulBackward0>)
tensor([196.,   6.,  87.]) tensor([ 86.6283,  85.5099, 107.5699], grad_fn=<MulBackward0>)
tensor([ 35., 184., 108.]) tensor([101.7757, 101.9018, 117.9278], grad_fn=<MulBackward0>)
tensor([61., 56., 88.]) tensor([59.0683, 66.3754, 71.9881], grad_fn=<MulBackward0>)
tensor([201., 252.,  18.]) tensor([144.5366, 144.9789, 175.2344], grad_fn=<MulBackward0>)
tensor([225.,  60.,   3.]) tensor([ 84.3926,  82.9001, 112.6064], grad_fn=<MulBackward0>)
tensor([172.,  23.,  60.]) tensor([72.1615, 73.2473, 93.7865], grad_fn=<MulBackward0>)
tensor([ 77., 197., 149.]) tensor([133.2974, 136.9729, 144.6756], grad_fn=<MulBackward0>)
tensor([109., 112., 228.]) tensor([136.7419, 148.0990, 151.6406], grad_fn=<MulBackward0>)
tensor([204., 128., 213.]) tensor([166.6475, 176.0030, 185.5202], grad_fn=<MulBackward0>)
tensor([198., 206., 227.]