In [8]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from typing import Iterable, Sequence, Tuple, List
import sys

In [2]:
def pytorch_init():
    device_id = 1
    torch.cuda.set_device(device_id)
    
    # Sanity checks
    assert torch.cuda.current_device() == 1, 'Using wrong GPU'
    assert torch.cuda.device_count() == 2, 'Cannot find both GPUs'
    assert torch.cuda.get_device_name(0) == 'GeForce RTX 2080 Ti', 'Wrong GPU name'
    assert torch.cuda.is_available() == True, 'GPU not available'
    return torch.device('cuda', device_id)
    
device = pytorch_init()
device

device(type='cuda', index=1)

In [6]:
img_size = (416, 416)

In [102]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_blocks = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=20, kernel_size=(5,5), padding=2),
            nn.Conv2d(in_channels=20, out_channels=40, kernel_size=(3,3), padding=1)
        )
        
    def forward(self, x: torch.Tensor):
        left, right = self.conv_blocks(x), self.conv_blocks(x)
        return self.conv_blocks(x)
        
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder_block = nn.Sequential(
            nn.Linear(in_features=416*416*3, out_features=500),
            nn.ReLU(),
            nn.Linear(in_features=500, out_features=10)
        )
        
    def forward(self, x: torch.Tensor):
        return self.decoder_block(x)

class FishNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        x = self.encoder()
        x = x.view(x.size(0), -1) # Kinda like ravel
        x = self.decoder(x)
        return x
        

In [103]:
enc = Encoder()
enc

Encoder(
  (conv_blocks): Sequential(
    (0): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): Conv2d(20, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
)

In [104]:
X = torch.randn(4, 3, 416, 416)

In [105]:
X[0].shape

torch.Size([3, 416, 416])

In [106]:
enc(X).shape

torch.Size([4, 40, 416, 416])

In [20]:
model = FishNet()

In [22]:
model

FishNet(
  (encoder): Encoder(
    (conv_blocks): Sequential(
      (0): Conv2d(3, 20, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (1): Conv2d(10, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
  )
  (decoder): Decoder(
    (decoder_block): Sequential(
      (0): Linear(in_features=519168, out_features=500, bias=True)
      (1): ReLU()
      (2): Linear(in_features=500, out_features=10, bias=True)
    )
  )
)

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

Trial stuff

In [29]:
a = torch.arange(0,4*4).view((4,4))
b = torch.arange(4*4,4*4+4*4).view((4,4))

In [58]:
a

tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11],
        [12, 13, 14, 15]])

In [59]:
b

tensor([[16, 17, 18, 19],
        [20, 21, 22, 23],
        [24, 25, 26, 27],
        [28, 29, 30, 31]])

In [55]:
torch.stack((a,b), dim=0)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]]])

In [61]:
torch.stack((a,b), dim=1)

tensor([[[ 0,  1,  2,  3],
         [16, 17, 18, 19]],

        [[ 4,  5,  6,  7],
         [20, 21, 22, 23]],

        [[ 8,  9, 10, 11],
         [24, 25, 26, 27]],

        [[12, 13, 14, 15],
         [28, 29, 30, 31]]])

In [62]:
torch.stack((a,b), dim=2)

tensor([[[ 0, 16],
         [ 1, 17],
         [ 2, 18],
         [ 3, 19]],

        [[ 4, 20],
         [ 5, 21],
         [ 6, 22],
         [ 7, 23]],

        [[ 8, 24],
         [ 9, 25],
         [10, 26],
         [11, 27]],

        [[12, 28],
         [13, 29],
         [14, 30],
         [15, 31]]])