In [70]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [166]:
def generate_task(size, w=1, h=1, r=-1, c=-1):
    inp = torch.zeros(size)
    row = r if r > -1 else torch.randint(high=inp.shape[0]+1-h, size=(1,))
    col = c if c > -1 else torch.randint(high=inp.shape[1]+1-w, size=(1,))
    inp[row:row+h, col:col+w] = 1
    y = (row+h / 2) / inp.shape[0]
    x = (col+w / 2) / inp.shape[1]
    return inp, (x,y)

def generate_task_multi(n, size, w=1, h=1):
    tasks = [generate_task(size, w, h) for _ in range(n)]
    inputs = torch.cat([t[0].unsqueeze(0) for t in tasks])
    targets = torch.cat([torch.Tensor(t[1]).unsqueeze(0) for t in tasks])
    return inputs, targets

In [81]:
generate_task((3,3),r=1,c=2)

(tensor([[0., 0., 0.],
         [0., 0., 1.],
         [0., 0., 0.]]),
 (0.8333333333333334, 0.5))

In [167]:
generate_task_multi(2, (3,3))

(tensor([[[0., 1., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],
 
         [[0., 0., 0.],
          [0., 0., 1.],
          [0., 0., 0.]]]),
 tensor([[0.5000, 0.1667],
         [0.8333, 0.5000]]))

In [230]:
class SimpleNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(9, 6)
        # self.l1 = nn.Linear(9, 10)
        # self.l2 = nn.Linear(10, 4)
        self.l3 = nn.Linear(6, 2)

    def forward(self, x):
        x = self.l1(x)
        x = F.relu(x)
        # x = self.l2(x)
        # x = F.relu(x)
        x = self.l3(x)
        return F.sigmoid(x)

In [234]:
n = SimpleNet()
n(torch.rand(3,3).flatten())

tensor([0.4152, 0.3778], grad_fn=<SigmoidBackward0>)

In [254]:
list(n.l1.parameters())

[Parameter containing:
 tensor([[-0.0279, -0.1615, -0.1847,  0.1738, -0.1177, -0.1456, -0.0497,  0.2570,
           0.2494],
         [-0.0864, -0.2241, -0.0113,  0.1516, -0.0625,  0.0921,  0.0319,  0.3319,
          -0.2088],
         [ 0.0308, -0.3119,  0.2688,  0.1937,  0.0335, -0.1324,  0.2974,  0.0052,
           0.0518],
         [-0.2810,  0.3037,  0.0787, -0.1210, -0.0253,  0.1470, -0.2797,  0.2380,
           0.1611],
         [ 0.2739, -0.3057, -0.1241,  0.0023, -0.1075, -0.0575,  0.2842, -0.1365,
           0.3323],
         [ 0.2377, -0.2173, -0.2285, -0.3108, -0.0778, -0.0553,  0.1110,  0.3067,
           0.0884]], requires_grad=True),
 Parameter containing:
 tensor([ 0.2140, -0.0766,  0.2614, -0.3035, -0.1236,  0.1785],
        requires_grad=True)]

In [242]:
net = SimpleNet()
optimizer = torch.optim.Adam(net.parameters(), lr=1e-1)
batch_size = 1000

for ep in range(200):
    inputs, targets = generate_task_multi(batch_size, (3,3))
    outputs = net(inputs.reshape(batch_size, -1))
    optimizer.zero_grad()
    l = F.mse_loss(outputs, targets)
    l.backward()
    optimizer.step()
    print(l.item())

0.072393998503685
0.06584163010120392
0.05400797724723816
0.041952814906835556
0.03077555261552334
0.021191036328673363
0.012576097622513771
0.008607640862464905
0.00665811263024807
0.005337037146091461
0.004860179964452982
0.004614601843059063
0.005133682396262884
0.006384794134646654
0.007100298069417477
0.006778639741241932
0.006301989778876305
0.006103878375142813
0.00553039601072669
0.00521641131490469
0.0049753496423363686
0.0040978481993079185
0.003960723988711834
0.0035063521936535835
0.003202746156603098
0.0025661771651357412
0.0022413444239646196
0.0018492850940674543
0.0017365488456562161
0.0014501906698569655
0.0015084551414474845
0.0015793148195371032
0.0017122459830716252
0.001788437832146883
0.001918598311021924
0.0018956727581098676
0.0016922680661082268
0.0015106796054169536
0.0013448076788336039
0.0013105530524626374
0.0011827106354758143
0.001182838692329824
0.0010219775140285492
0.0010235029039904475
0.0009661876829341054
0.0009429123601876199
0.0009223556844517589


In [255]:
torch.cat([torch.eye(3).unsqueeze(0), torch.eye(3).unsqueeze(0)]).reshape(2, -1)

tensor([[1., 0., 0., 0., 1., 0., 0., 0., 1.],
        [1., 0., 0., 0., 1., 0., 0., 0., 1.]])

In [275]:
t, r = generate_task((3,3))
t, r, net(t.flatten())

(tensor([[0., 0., 0.],
         [0., 0., 0.],
         [1., 0., 0.]]),
 (tensor([0.1667]), tensor([0.8333])),
 tensor([0.1667, 0.8333], grad_fn=<SigmoidBackward0>))