In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class CNN(nn.Module):
    """Small convolutional network producing softmax weights over a cash column plus one column per row.

    Pipeline: two (1, k) convolutions collapse the last dimension of ``x`` to a
    single column, the extra input ``w`` is appended as one more channel, a
    1x1 "votes" convolution yields one score per row, and a learnable scalar
    bias ``b`` is prepended as column 0 before a softmax over columns.

    Args:
        in_features: number of input channels of ``x``.
        rows: size of dim 2 of ``x`` (the demo passes the number of coins).
            Not needed to build any layer; kept for interface compatibility.
        cols: size of dim 3 of ``x`` (the demo passes the window length).
            Must be >= 3 because ``conv1`` (kernel width 3, no padding)
            removes two columns.
        device: device the module's parameters are moved to after construction.

    Raises:
        ValueError: if ``cols`` < 3.
    """

    def __init__(self, in_features, rows, cols, device=torch.device("cpu")):
        super().__init__()

        out1 = 2
        out2 = 20
        kernel1 = (1, 3)
        if cols < kernel1[1]:
            raise ValueError(f"cols must be >= {kernel1[1]}, got {cols}")
        # conv1 has no padding, so it shrinks the last dim by kernel1[1] - 1;
        # conv2 then spans the whole remaining width, leaving a single column.
        kernel2 = (1, cols - (kernel1[1] - 1))

        self.conv1 = nn.Conv2d(in_features, out1, kernel1)
        self.conv2 = nn.Conv2d(out1, out2, kernel2)
        # +1 input channel for the extra tensor `w` concatenated in forward()
        # (the original comment calls it the appended last_weights).
        self.votes = nn.Conv2d(out2 + 1, 1, (1, 1))

        # Learnable scalar bias for the cash column prepended in forward()
        # (originally commented as the "BTC bias").
        self.b = nn.Parameter(torch.zeros((1, 1)))

        # `device` used to be accepted but ignored; honour it.
        self.to(device)

    def forward(self, x, w):
        """Return softmax weights of shape (batch, rows + 1).

        Args:
            x: tensor of shape (batch, in_features, rows, cols).
            w: tensor of shape (batch, 1, rows, 1); the demo builds it from a
               (batch, rows) tensor with ``w[:, None, :, None]``.

        Returns:
            Tensor of shape (batch, rows + 1) whose rows sum to 1. Column 0
            comes from the learnable bias ``b``; columns 1.. from the votes layer.
        """
        x = F.relu(self.conv1(x))      # (B, out1, rows, cols - 2)
        x = F.relu(self.conv2(x))      # (B, out2, rows, 1)
        x = torch.cat((x, w), dim=1)   # (B, out2 + 1, rows, 1)
        x = self.votes(x)              # (B, 1, rows, 1)
        # Drop only the two known singleton dims. A bare torch.squeeze() would
        # also drop the batch dim when B == 1 (or the rows dim when rows == 1)
        # and break the concatenation below.
        x = x.squeeze(3).squeeze(1)    # (B, rows)

        # Same scalar bias for every sample; gradients accumulate into self.b.
        cash = self.b.expand(x.size(0), 1)
        x = torch.cat((cash, x), dim=1)
        return F.softmax(x, dim=1)

In [3]:
# Smoke test: run one forward pass on random inputs of the expected shapes.
torch.manual_seed(0)  # reproducible random inputs and weight init
batch_size = 16
feat = 2      # input channels
window = 10   # length of the last dim of x (cols)
coins = 5     # size of dim 2 of x (rows)
x = torch.rand(batch_size, feat, coins, window)
w = torch.rand(batch_size, coins)[:, None, :, None]  # (batch, coins) -> (batch, 1, coins, 1)
model = CNN(feat, coins, window)
out = model(x, w)
out.shape

torch.Size([16, 6])

In [9]:
# Shape of every trainable tensor, in registration order
# (the model's own bias `b` is listed before the conv layers' tensors).
for param in model.parameters():
    print(param.shape)

torch.Size([1, 1])
torch.Size([2, 2, 1, 3])
torch.Size([2])
torch.Size([20, 2, 1, 8])
torch.Size([20])
torch.Size([1, 21, 1, 1])
torch.Size([1])


In [5]:
# A few Adam steps on a dummy loss (sum of squared output weights) to check
# that the model trains end to end. `model.b` is an nn.Parameter attribute,
# so model.parameters() already includes it; passing it again in a separate
# param group would make Adam raise (a parameter may appear in only one group).
learning_rate = 1e-4
n_steps = 10

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for _ in range(n_steps):
    y = model(x, w)
    loss = y.pow(2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

In [6]:
# Display the learnable bias parameter registered as `b` on the model.
dict(model.named_parameters())["b"]

Parameter containing:
tensor([[0.0010]], requires_grad=True)

In [None]:
# Scratch check of the (batch, n) -> (batch, 1, n, 1) reshape used to build
# the model's `w` input. Uses its own name so the global `w` fed to the
# model is not overwritten (re-running the model cells afterwards would fail).
w_demo = torch.rand(2, 2)
w_demo[:, None, :, None].shape

In [None]:
# Scratch check: dtype of a (1, 1) trainable zero tensor after tiling it to (5, 1) with repeat().
b = torch.zeros(1, 1, requires_grad=True)
b.repeat(5, 1).dtype

In [None]:
# Scratch check: size of the leading dimension of `w`.
w.shape[0]