In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def generate_task(size, w=1, h=1, r=-1, c=-1):
    """Build one localisation sample: a zero grid containing a rectangle of ones.

    Args:
        size: Shape of the grid; the code indexes its first two dimensions
            as (rows, cols).
        w: Rectangle width in columns.
        h: Rectangle height in rows.
        r: Top row of the rectangle. Any value <= -1 means "pick at random"
            so that the rectangle fits inside the grid.
        c: Left column of the rectangle, with the same convention as ``r``.

    Returns:
        ``(grid, (x, y))`` where ``grid`` is the float tensor and ``x`` / ``y``
        are the rectangle centre divided by the grid width / height.
        When a position is given explicitly the matching coordinate is a
        Python float; when it is drawn at random it is a 1-element tensor,
        because ``torch.randint`` is called with ``size=(1,)``.
    """
    grid = torch.zeros(size)

    # Row is resolved before column so the random draws happen in a fixed order.
    if r > -1:
        top = r
    else:
        top = torch.randint(high=grid.shape[0] + 1 - h, size=(1,))
    if c > -1:
        left = c
    else:
        left = torch.randint(high=grid.shape[1] + 1 - w, size=(1,))

    grid[top:top + h, left:left + w] = 1

    # Centre = start + half the extent, then normalised by the grid dimension.
    center_y = (top + h / 2) / grid.shape[0]
    center_x = (left + w / 2) / grid.shape[1]
    return grid, (center_x, center_y)

def generate_task_multi(n, size, w=1, h=1):
    """Generate ``n`` random samples with ``generate_task`` and batch them.

    Args:
        n: Number of samples to draw.
        size: Grid shape forwarded to ``generate_task``.
        w: Rectangle width forwarded to ``generate_task``.
        h: Rectangle height forwarded to ``generate_task``.

    Returns:
        ``(inputs, targets)``: the grids and the ``(x, y)`` centre pairs, each
        concatenated along a new leading batch dimension.
    """
    grids = []
    centers = []
    for _ in range(n):
        grid, center = generate_task(size, w, h)
        # Give every sample a leading axis of length 1 so cat() batches them.
        grids.append(grid[None])
        centers.append(torch.Tensor(center)[None])

    inputs = torch.cat(grids)
    targets = torch.cat(centers)
    return inputs, targets

In [3]:
# Deterministic sample: single cell at row 1, column 2 of a 3x3 grid.
generate_task(size=(3, 3), r=1, c=2)

(tensor([[0., 0., 0.],
         [0., 0., 1.],
         [0., 0., 0.]]),
 (0.8333333333333334, 0.5))

In [4]:
# Batch of two randomly placed samples on a 3x3 grid.
generate_task_multi(n=2, size=(3, 3))

(tensor([[[0., 1., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 1., 0.]]]),
 tensor([[0.5000, 0.1667],
         [0.5000, 0.8333]]))

In [5]:
class SimpleNet(nn.Module):
    """Small fully connected regressor: Linear -> ReLU -> Linear -> sigmoid.

    With the default sizes it maps a flattened 3x3 grid (9 values) to two
    outputs squashed into (0, 1) by the final sigmoid.

    Args:
        in_features: Length of the flattened input vector.
        hidden: Width of the single hidden layer.
        out_features: Number of output values.
    """

    def __init__(self, in_features: int = 9, hidden: int = 6, out_features: int = 2) -> None:
        super().__init__()
        # Attribute names l1 / l3 are kept so existing code that reads
        # ``net.l1`` or loads a saved state dict continues to work.
        self.l1 = nn.Linear(in_features, hidden)
        self.l3 = nn.Linear(hidden, out_features)

    def forward(self, x):
        """Run the network on ``x`` whose last dimension is ``in_features``.

        Returns:
            Tensor with last dimension ``out_features`` and values in (0, 1).
        """
        x = F.relu(self.l1(x))
        # torch.sigmoid is the canonical call; F.sigmoid has emitted a
        # deprecation warning in some PyTorch releases.
        return torch.sigmoid(self.l3(x))

In [6]:
# Smoke test: push one random flattened 3x3 grid through an untrained network.
n = SimpleNet()
n(torch.rand(3, 3).reshape(-1))



tensor([0.5370, 0.4988], grad_fn=<SigmoidBackward0>)

In [7]:
# Show the parameters of the first linear layer before any training.
[*n.l1.parameters()]

[Parameter containing:
 tensor([[-0.3328, -0.0636,  0.2452,  0.0604, -0.1827,  0.0234,  0.1259, -0.1498,
          -0.3225],
         [ 0.1021,  0.1426, -0.2504,  0.2056, -0.2841,  0.1154,  0.0356,  0.1212,
           0.1263],
         [ 0.0772,  0.0221, -0.0717,  0.2704, -0.0884, -0.0422,  0.2128, -0.1992,
           0.0631],
         [-0.0301, -0.0435, -0.2483, -0.2108, -0.2118, -0.2234,  0.0784,  0.0448,
          -0.0629],
         [ 0.2877, -0.2847, -0.1376,  0.0502, -0.2650, -0.0140,  0.1406,  0.1641,
          -0.0381],
         [ 0.2274,  0.0945,  0.1834,  0.0388,  0.2971, -0.0749, -0.2361, -0.3216,
           0.3275]], requires_grad=True),
 Parameter containing:
 tensor([-0.0329,  0.2423,  0.0746,  0.1044,  0.1100, -0.0199],
        requires_grad=True)]

In [None]:
# Training configuration, gathered in one place so it is easy to tune.
SEED = 0
GRID_SIZE = (3, 3)
LEARNING_RATE = 1e-1
EPOCHS = 200
LOG_EVERY = 20
batch_size = 1000

# Seed before building the model so that weight initialisation and the
# randomly generated batches are reproducible on a fresh kernel.
torch.manual_seed(SEED)
net = SimpleNet()
optimizer = torch.optim.Adam(net.parameters(), lr=LEARNING_RATE)

for ep in range(EPOCHS):
    # A fresh random batch is drawn every epoch; there is no fixed dataset.
    inputs, targets = generate_task_multi(batch_size, GRID_SIZE)

    optimizer.zero_grad()
    outputs = net(inputs.reshape(batch_size, -1))
    loss = F.mse_loss(outputs, targets)
    loss.backward()
    optimizer.step()

    # Log periodically (and on the last epoch) instead of printing 200 lines.
    if ep % LOG_EVERY == 0 or ep == EPOCHS - 1:
        print(f"epoch {ep:3d}  loss {loss.item():.6f}")

In [14]:
# Show every parameter of the trained network.
[*net.parameters()]

[Parameter containing:
 tensor([[-8.1249e-01,  9.4146e-01,  1.5088e+00, -9.4994e-01, -7.2831e-01,
          -1.0574e+00, -1.2965e-02, -7.6676e-01, -7.2392e-01],
         [-1.1611e+00, -4.5299e-01,  2.3863e-01, -2.4867e-01, -1.7436e-01,
           1.0501e+00, -1.5058e+00,  7.6804e-01,  1.9277e+00],
         [-2.9835e-01, -2.2615e-01, -2.0322e-01,  1.2935e+00, -5.4868e-01,
          -2.3041e-01, -1.3808e-01, -2.5425e-01, -7.7509e-01],
         [-2.2390e-03, -4.8303e-01, -1.8564e-01, -4.6475e-01, -3.2737e-01,
           1.3427e+00, -4.5625e-01, -4.3538e-01,  1.1170e-03],
         [-6.6414e-01, -1.3722e-01, -5.0046e-01, -6.9698e-01, -2.1085e+00,
          -4.4774e-01,  9.6433e-01,  7.7577e-01,  1.8755e-01],
         [ 1.0511e+00,  7.1986e-01,  3.8167e-01, -1.6978e-01, -1.2524e+00,
           1.4188e+00, -1.1793e+00,  7.6292e-01,  1.2019e+00]],
        requires_grad=True),
 Parameter containing:
 tensor([ 0.3085,  0.9056, -0.2729, -0.4905,  1.5498,  0.8669],
        requires_grad=True),
 Pa

In [9]:
# Check that flattening a batch of two 3x3 grids yields one 9-vector per row.
torch.stack([torch.eye(3), torch.eye(3)]).flatten(start_dim=1)

tensor([[1., 0., 0., 0., 1., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 0., 0., 0., 1.]])

In [13]:
# Compare the true centre of one random sample with the trained network's prediction.
t, r = generate_task(size=(3, 3))
(t, r, net(t.reshape(-1)))

(tensor([[0., 0., 0.],
         [1., 0., 0.],
         [0., 0., 0.]]),
 (tensor([0.1667]), tensor([0.5000])),
 tensor([0.1667, 0.5000], grad_fn=<SigmoidBackward0>))