In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvNet(nn.Module):
	def __init__(self, inchannels, outchannels, hchannels, kernel=2, nonlin=nn.ReLU(), final_nonlin=nn.Identity()):
		super(ConvNet, self).__init__()
		
		self.in_channels, self.out_channels = inchannels, outchannels
		self.nhidden = len(hchannels)
		channels = [inchannels] + hchannels + [outchannels]
		self.nonlin = [nonlin for k in range(self.nhidden)] + [final_nonlin]
		self.conv = nn.ModuleList(
			[
				nn.ConvTranspose1d(channels[k], channels[k+1], kernel, stride=2) for k in range(self.nhidden + 1)
			]
		)
	def forward(self, x):
		for conv, nlin in zip(self.conv, self.nonlin):
			print('* '*10)
			print(x)
			print(x.shape)
			print(conv)
			print(nlin)
			print('* '*10)
			x = nlin(conv(x))
		return x


class ImmDiff(nn.Module):
    def __init__(self):
        super(ImmDiff, self).__init__()

        self.nurbs_to_img = ConvNet(1000, 32, [500 for i in range(3)], nonlin=torch.sin)

        self.up_conv_1 = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
        self.up_conv_2 = nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
    
    def forward(self, x):
        x = F.tanh(self.nurbs_to_img(x)).unsqueeze(1)
        print(x.shape)
        x = F.tanh(self.up_conv_1(x))
        print(x.shape)
        return self.up_conv_2(x)

In [21]:
x = torch.ones((16,1,1000,2))
# network = ConvNet(1000, 32, [500 for i in range(3)], nonlin=torch.sigmoid)
network = ImmDiff()
y_hat = network(x)

* * * * * * * * * * 
tensor([[[[1., 1.],
          [1., 1.],
          [1., 1.],
          ...,
          [1., 1.],
          [1., 1.],
          [1., 1.]]],


        [[[1., 1.],
          [1., 1.],
          [1., 1.],
          ...,
          [1., 1.],
          [1., 1.],
          [1., 1.]]],


        [[[1., 1.],
          [1., 1.],
          [1., 1.],
          ...,
          [1., 1.],
          [1., 1.],
          [1., 1.]]],


        ...,


        [[[1., 1.],
          [1., 1.],
          [1., 1.],
          ...,
          [1., 1.],
          [1., 1.],
          [1., 1.]]],


        [[[1., 1.],
          [1., 1.],
          [1., 1.],
          ...,
          [1., 1.],
          [1., 1.],
          [1., 1.]]],


        [[[1., 1.],
          [1., 1.],
          [1., 1.],
          ...,
          [1., 1.],
          [1., 1.],
          [1., 1.]]]])
torch.Size([16, 1, 1000, 2])
ConvTranspose1d(1000, 500, kernel_size=(2,), stride=(2,))
<built-in method sin of type object at 0x7f7

RuntimeError: Expected 3-dimensional input for 3-dimensional weight [1000, 500, 2], but got 4-dimensional input of size [16, 1, 1000, 2] instead

In [20]:
print(y_hat.shape)

torch.Size([2, 1, 128, 128])
