In [25]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from typing import Iterable, Sequence, Tuple, List, Union, Optional
import sys

In [26]:
def pytorch_init():
    device_id = 1
    torch.cuda.set_device(device_id)
    
    # Sanity checks
    assert torch.cuda.current_device() == 1, 'Using wrong GPU'
    assert torch.cuda.device_count() == 2, 'Cannot find both GPUs'
    assert torch.cuda.get_device_name(0) == 'GeForce RTX 2080 Ti', 'Wrong GPU name'
    assert torch.cuda.is_available() == True, 'GPU not available'
    return torch.device('cuda', device_id)
    
device = pytorch_init()
device

device(type='cuda', index=1)

In [27]:
img_size = (416, 416)
batch_size = 2
n_channels = 3

X_l = torch.randn(batch_size, n_channels, *img_size, device=device)
X_r = torch.randn(batch_size, n_channels, *img_size, device=device)

print(
    f'X_l: {X_l.shape}\n',
    f'\rX_r: {X_r.shape}'
)

X_l: torch.Size([2, 3, 416, 416])
 X_r: torch.Size([2, 3, 416, 416])


In [92]:
class EncoderBlock(nn.Module):
    def __init__(self, 
            in_channels: int, 
            out_channels: int, 
            kernel_size: int,
            padding: Optional[int] = None,
            **kwargs
        ) -> None:
        '''Defaults to same padding'''
        
        super().__init__()
        
        if padding is None:
            padding = int(kernel_size // 2)
        
        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels, 
                out_channels=out_channels, 
                kernel_size=kernel_size, 
                padding=padding, 
                **kwargs
            ),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=out_channels),
        )
        
    def forward(self, x):
        return self.layers(x)
        

class Encoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        
        self.blocks = nn.Sequential(
            EncoderBlock( 3, 96, 7),
            EncoderBlock(96, 32, 5),
            EncoderBlock(32, 32, 3),
        )
        
        self.latent = EncoderBlock(32*2, 16, 1)
        
    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        left, right = self.blocks(x[0]), self.blocks(x[1])
        latent = self.latent(torch.cat((left, right), dim=1))
        return latent

In [93]:
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.decoder_block = nn.Sequential(
            nn.Linear(in_features=16*26*26, out_features=100),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=100, out_features=10),
        )
        
    def forward(self, x: torch.Tensor):
        return self.decoder_block(x)

In [94]:
class FishNet(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.encoder = Encoder()
        self.decoder = Decoder()
        
    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        x = self.encoder(x)
        # .size is same as reshape
        # .size(0) is respective to batch size
        x = x.view(x.size(0), -1)
        x = self.decoder(x)
        return x

In [95]:
model = FishNet().to(device)
model

FishNet(
  (encoder): Encoder(
    (blocks): Sequential(
      (0): EncoderBlock(
        (layers): Sequential(
          (0): Conv2d(3, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (2): ReLU(inplace=True)
          (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): EncoderBlock(
        (layers): Sequential(
          (0): Conv2d(96, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
          (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (2): ReLU(inplace=True)
          (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): EncoderBlock(
        (layers): Sequential(
          (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): MaxPool2d(kernel_size=2, stride=2, padding=0, d

In [96]:
with torch.no_grad():
    model.eval()
    print(model((X_l, X_r)).shape)

torch.Size([2, 10])


In [16]:
# 'a' MUST BE IN CORRECT DTYPE OR PYTORCH WILL SHIT ITSELF WTF
a = torch.arange((2), device=device, dtype=torch.int64)

def train_gen(n: int):
    for i in range(n):    
        yield ((X_l, X_r), a)

        
def train(model: nn.Module, n_epochs: int):
    model.train()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()
    
    for epoch in range(n_epochs):
        pbar = tqdm(
            iterable = enumerate(range(10)),
            total = 10,
            unit = ' batches',
            desc=f' Epoch {epoch+1}/{n_epochs}',
#             file=sys.stdout,
            ascii=True,
            position=0
        )
        
        running_loss: float = 0.0

        i: int
        minibatch: Tuple[torch.Tensor, torch.Tensor]
        tqdm_dict = {'loss':0}
        for i, minibatch in pbar:
            
            optimizer.zero_grad()
            outputs = model((X_l, X_r))
            loss = criterion(outputs, a)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            tqdm_dict['loss'] = running_loss
            pbar.set_postfix(tqdm_dict)
            
        
train(model, 10)

 Epoch 1/10: 100%|##########| 10/10 [00:00<00:00, 25.15 batches/s, loss=2.55]
 Epoch 2/10: 100%|##########| 10/10 [00:00<00:00, 27.81 batches/s, loss=1.19e-6]
 Epoch 3/10: 100%|##########| 10/10 [00:00<00:00, 27.84 batches/s, loss=5.96e-7]
 Epoch 4/10: 100%|##########| 10/10 [00:00<00:00, 27.84 batches/s, loss=5.96e-7]
 Epoch 5/10: 100%|##########| 10/10 [00:00<00:00, 27.83 batches/s, loss=5.96e-7]
 Epoch 6/10: 100%|##########| 10/10 [00:00<00:00, 27.69 batches/s, loss=2.38e-7]
 Epoch 7/10: 100%|##########| 10/10 [00:00<00:00, 27.83 batches/s, loss=0]
 Epoch 8/10: 100%|##########| 10/10 [00:00<00:00, 27.91 batches/s, loss=0]
 Epoch 9/10: 100%|##########| 10/10 [00:00<00:00, 27.79 batches/s, loss=0]
 Epoch 10/10: 100%|##########| 10/10 [00:00<00:00, 27.80 batches/s, loss=0]


<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

Trial stuff

In [10]:
raise ValueError

ValueError: 

In [55]:
enc = Encoder().to(device)
print(enc)

enc_output = enc((X_l, X_r))
enc_output.shape

# dec = Decoder()
# print(dec)

# # .size(0) is respective to batch size
# dec_output = dec(enc_output.view(enc_output.size(0), -1))
# dec_output.shape

Encoder(
  (blocks): Sequential(
    (0): EncoderBlock(
      (layers): Sequential(
        (0): Conv2d(3, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
        (1): MaxPool2d(kernel_size=7, stride=2, padding=0, dilation=1, ceil_mode=False)
        (2): ReLU(inplace=True)
        (3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (1): EncoderBlock(
      (layers): Sequential(
        (0): Conv2d(96, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
        (1): MaxPool2d(kernel_size=5, stride=2, padding=0, dilation=1, ceil_mode=False)
        (2): ReLU(inplace=True)
        (3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
    )
    (2): EncoderBlock(
      (layers): Sequential(
        (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
        (2): ReLU(inplace=True)
    

torch.Size([2, 16, 25, 25])