In [1]:
import matplotlib.pyplot as plt
import numpy as np
import random
from PIL import Image
import PIL.ImageOps    

import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.utils
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch import optim
import torch.nn.functional as F

In [2]:
class A (nn.Module):
    def __init__(self):
        super().__init__()

        self.fcn1 = nn.Linear(128, 512)
        self.act = nn.ReLU(inplace=True)
        self.fcn2 = nn.Linear(512, 1024)
        self.fcn3 = nn.Linear(1024, 512)
        self.fcn4 = nn.Linear(512, 128)

    def forward (self, x):
        x = self.fcn1(x)
        x = self.act(x)
        x = self.fcn2(x)
        x = self.act(x)
        x = self.fcn3(x)
        x = self.act(x)
        x = self.fcn4(x)
        return x
    
class B (nn.Module):
    def __init__(self):
        super().__init__()

        self.A = A()
        self.fcn1 = nn.Linear(128, 10)
        self.act = nn.ReLU(inplace=True)
        self.sig = nn.Sigmoid()

    def forward (self, x):
        x = self.A(x)
        x = self.fcn1(x)
        x = self.sig(x)
        return x


model = B()
optimizer = optim.Adam(model.parameters(), lr=0.01)
num_epochs = 10

In [4]:
model.parameters

<bound method Module.parameters of B(
  (A): A(
    (fcn1): Linear(in_features=128, out_features=512, bias=True)
    (act): ReLU(inplace=True)
    (fcn2): Linear(in_features=512, out_features=1024, bias=True)
    (fcn3): Linear(in_features=1024, out_features=512, bias=True)
    (fcn4): Linear(in_features=512, out_features=128, bias=True)
  )
  (fcn1): Linear(in_features=128, out_features=10, bias=True)
  (act): ReLU(inplace=True)
  (sig): Sigmoid()
)>

In [2]:
class CNN_select(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super(CNN_select, self).__init__()

        if kernel_size==3:
            self.cnn = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        else:
            self.cnn = nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True)

    
    def forward (self, x):
        out = self.cnn(x)
        out = self.bn(out)
        out = self.act(out)

        return out
        

def CNN_Block (cnn1, cnn2, x):
    out = cnn1(x)
    out = cnn2(out)
    return out


In [3]:
IMG_CHANNELS = 3

In [15]:

class MultiScale_SNet_ShareWt (nn.Module):
    def __init__ (self):
        super(MultiScale_SNet_ShareWt, self).__init__()
        
        self.cnn_3 = nn.ModuleList([])
        in_out_channels_3 = [(IMG_CHANNELS,8), (8,8), (16,16), (32,32), (64,64)]
        for i in range(5):
            self.cnn_3.append( CNN_select(in_out_channels_3[i][0], in_out_channels_3[i][1], 3) )
        
        self.cnn_2 = nn.ModuleList([])
        in_out_channels_2 = [(8,16), (16,32), (32,64), (64,128)]
        for i in range(4):
            self.cnn_2.append( CNN_select(in_out_channels_2[i][0], in_out_channels_2[i][1], 2) )

        self.convT = nn.ConvTranspose2d( 128, IMG_CHANNELS, 16, 16 )

        # self.alpha = nn.Parameter(torch.tensor(np.random.random_sample(), requires_grad=True).float())
        self.alpha = nn.Parameter(torch.tensor(0.5, requires_grad=True))
        self.beta  = nn.Parameter(torch.tensor(np.random.random_sample(), requires_grad=True).float())
        self.gamma = nn.Parameter(torch.tensor(np.random.random_sample(), requires_grad=True).float())
        self.delta = nn.Parameter(torch.tensor(np.random.random_sample(), requires_grad=True).float())

        self.max_pool_3d_8 = nn.MaxPool3d((8,1,1), (8,1,1), 0)
        self.max_pool_3d_4 = nn.MaxPool3d((4,1,1), (4,1,1), 0)
        self.max_pool_3d_2 = nn.MaxPool3d((2,1,1), (2,1,1), 0)

        self.mul = torch.multiply

    def forward_once (self, img):
        
        out = self.cnn_3[0](img)
        
        out = CNN_Block( self.cnn_3[1], self.cnn_2[0], out )
        x1 = out.clone()
        x1 = self.max_pool_3d_8(x1)
        x1 = x1.reshape((-1, 128, 16, 16))
        x1 = self.mul(self.alpha, x1)

        out = CNN_Block( self.cnn_3[2], self.cnn_2[1], out )
        x2 = out.clone()
        x2 = self.max_pool_3d_4(x2)
        x2 = x2.reshape((-1, 128, 16, 16))
        x2 = self.mul(self.beta, x2)

        out = CNN_Block( self.cnn_3[3], self.cnn_2[2], out )
        x3 = out.clone()
        x3 = self.max_pool_3d_2(x3)
        x3 = x3.reshape((-1, 128, 16, 16))
        x3 = self.mul(self.gamma, x3)

        out = CNN_Block( self.cnn_3[4], self.cnn_2[3], out )
        out = self.mul(self.delta, out)
        
        return x1, x2, x3, out

    def forward_feature_extr (self, img1, img2):
        x1, x2, x3, x4 = self.forward_once(img1)
        x5, x6, x7, x8 = self.forward_once(img2)

        out  = x4 + x8
        out += x3 + x7
        out += x2 + x6
        out += x1 + x5

        return out
    
    def forward (self, img1, img2):
        out = self.forward_feature_extr(img1, img2)
        out = self.convT(out)

        return out

In [16]:
x1 = torch.randn(1,3,256,256).float()
x2 = torch.randn(1,3,256,256).float()
x1.shape, x1.dtype

(torch.Size([1, 3, 256, 256]), torch.float32)

In [17]:
net = MultiScale_SNet_ShareWt()
y = net(x1, x2)
y.shape

torch.Size([1, 3, 256, 256])

In [19]:
net.requires_grad_

<bound method Module.requires_grad_ of MultiScale_SNet_ShareWt(
  (cnn_3): ModuleList(
    (0): CNN_select(
      (cnn): Conv2d(3, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (1): CNN_select(
      (cnn): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (2): CNN_select(
      (cnn): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (act): ReLU(inplace=True)
    )
    (3): CNN_select(
      (cnn): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_s

In [12]:
class Dan(nn.Module):
    def __init__(self):
        super(Dan, self).__init__()
        self.alpha = nn.Parameter(torch.tensor(0.5, requires_grad=True))

In [25]:
model = Dan()
for param in model.parameters():
    param.requires_grad = True
model.requires_grad_

<bound method Module.requires_grad_ of Dan()>

In [None]:
z

In [None]:
x1

In [None]:
torch.multiply(x1, z)

In [None]:
net.zero_grad()

In [3]:
from torch import diff, abs, concat
from torch import nn
import torch

In [4]:
a = torch.tensor([1,2,3])
b = torch.tensor([4,5,6])

In [11]:
z = torch.abs(torch.sub( a,b ))

In [12]:
z

tensor([3, 3, 3])

In [14]:
a[a >= 2] = 3
a

tensor([1, 3, 3])

In [21]:
a = torch.abs(torch.rand(1,3,4,4))
a

tensor([[[[0.0555, 0.3193, 0.3099, 0.2801],
          [0.1767, 0.9785, 0.6744, 0.1911],
          [0.7703, 0.5647, 0.7983, 0.4751],
          [0.7458, 0.9581, 0.8545, 0.3156]],

         [[0.5464, 0.4841, 0.4404, 0.3347],
          [0.2849, 0.7475, 0.5635, 0.7833],
          [0.9801, 0.0139, 0.6048, 0.0775],
          [0.3822, 0.5509, 0.9224, 0.6089]],

         [[0.5026, 0.7549, 0.2861, 0.0518],
          [0.6243, 0.1474, 0.3982, 0.8210],
          [0.4324, 0.8772, 0.6539, 0.4765],
          [0.8252, 0.1919, 0.1516, 0.1080]]]])

In [22]:
a[a>=0.7] = 1
a[a<0.7] = 0
a

tensor([[[[0., 0., 0., 0.],
          [0., 1., 0., 0.],
          [1., 0., 1., 0.],
          [1., 1., 1., 0.]],

         [[0., 0., 0., 0.],
          [0., 1., 0., 1.],
          [1., 0., 0., 0.],
          [0., 0., 1., 0.]],

         [[0., 1., 0., 0.],
          [0., 0., 0., 1.],
          [0., 1., 0., 0.],
          [1., 0., 0., 0.]]]])