In [17]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm import tqdm
from typing import Iterable, Sequence, Tuple, List, Union, Optional, Dict
import sys
import generators

In [2]:
def pytorch_init():
    '''Select GPU 1 after sanity-checking the expected two-GPU setup.

    Returns:
        torch.device: the CUDA device with index 1.

    Raises:
        AssertionError: if CUDA is unavailable, the number of visible GPUs
            is not 2, GPU 0 does not have the expected name, or the current
            device could not be switched to index 1.
    '''
    device_id = 1

    # Check availability and device count *before* touching the device:
    # on a machine without (enough) GPUs, `set_device` may fail with an
    # exception other than AssertionError, which callers that catch only
    # AssertionError would not handle.
    # Explicit raises are used instead of `assert` so the checks survive `python -O`.
    if not torch.cuda.is_available():
        raise AssertionError('GPU not available')
    if torch.cuda.device_count() != 2:
        raise AssertionError('Cannot find both GPUs')
    # NOTE(review): the name check looks at device 0 although device 1 is
    # the one being selected -- kept as in the original; confirm intent.
    if torch.cuda.get_device_name(0) != 'GeForce RTX 2080 Ti':
        raise AssertionError('Wrong GPU name')

    torch.cuda.set_device(device_id)
    if torch.cuda.current_device() != device_id:
        raise AssertionError('Using wrong GPU')

    return torch.device('cuda', device_id)

# Prefer the GPU, but fall back to the CPU when GPU initialisation fails.
try:
    device = pytorch_init()
except (AssertionError, RuntimeError) as e:
    # AssertionError: one of the sanity checks in `pytorch_init` failed.
    # RuntimeError: torch itself reported a CUDA problem (e.g. no driver or
    # an invalid device index); previously this skipped the CPU fallback.
    print('GPU could not initialize, got error:', e)
    device = torch.device('cpu')

print('Using device:', device)

Using device: cuda:1


In [3]:
# Synthetic stand-in data: one random "left" and one random "right" image
# batch, each shaped (batch_size, n_channels, height, width).
img_size = (416, 416)
batch_size = 2
n_channels = 3

X_l = torch.randn(batch_size, n_channels, *img_size, device=device)
X_r = torch.randn(batch_size, n_channels, *img_size, device=device)

# One shape per line; `sep` replaces the earlier '\n' + '\r' workaround,
# which left a stray separator space in front of the second line.
print(
    f'X_l: {X_l.shape}',
    f'X_r: {X_r.shape}',
    sep='\n',
)

X_l: torch.Size([2, 3, 416, 416])
 X_r: torch.Size([2, 3, 416, 416])


In [4]:
class EncoderBlock(nn.Module):
    '''Conv2d -> MaxPool2d(kernel 2, stride 2) -> ReLU -> BatchNorm2d.

    Passing ``in_channels=-1`` reuses the ``out_channels`` of the most
    recently constructed ``EncoderBlock``, so consecutive blocks can be
    chained without repeating channel counts.
    '''

    # Shared by ALL instances: remembers the `out_channels` of the block that
    # was constructed last, which is what `in_channels=-1` resolves to.
    prev_args: Dict[str, Union[None, int]] = {
        'out_channels': None,
    }

    # Kept as a class attribute for backward compatibility; its message is
    # reused when raising below.
    NotInitializedError = ValueError('EncoderBlock has not been initialized before, cannot infer in_channels')

    def __init__(self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            padding: Optional[int] = None,
            **kwargs
        ) -> None:
        '''Build the block.

        Args:
            in_channels: input channels of the convolution, or -1 to reuse
                the `out_channels` of the previously constructed block.
            out_channels: output channels of the convolution (also the
                number of BatchNorm features).
            kernel_size: convolution kernel size.
            padding: convolution padding; defaults to ``kernel_size // 2``
                ("same" padding for odd kernel sizes at stride 1).
            **kwargs: forwarded to ``nn.Conv2d``.

        Raises:
            ValueError: if `in_channels` is -1 but no block has been
                constructed before.
        '''

        super().__init__()

        if in_channels == -1:
            in_channels = EncoderBlock.prev_args['out_channels']
            if in_channels is None:
                # Class attributes are not visible as bare names inside a
                # method, so the attribute must be reached through the class.
                # A fresh exception is raised so tracebacks are not shared.
                raise ValueError(*EncoderBlock.NotInitializedError.args)
        EncoderBlock.prev_args['out_channels'] = out_channels

        if padding is None:
            padding = int(kernel_size // 2)

        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                **kwargs
            ),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.ReLU(inplace=True),
            nn.BatchNorm2d(num_features=out_channels),
        )

    def forward(self, x):
        '''Apply the conv / pool / ReLU / batch-norm stack to `x`.'''
        return self.layers(x)
        

class Encoder(nn.Module):
    '''Two-input encoder: both images go through the same stack of blocks,
    the results are concatenated on the channel axis and reduced by a final
    1x1 block.
    '''

    def __init__(self, input_shape: Tuple[int, int, int]) -> None:
        '''Input shape is used for assertion and inferring decoder sizes.

        Args:
            input_shape: (channels, height, width) of a single input image.

        Raises:
            ValueError: if `input_shape` does not have exactly three entries.
        '''
        super().__init__()

        if len(input_shape) != 3:
            raise ValueError('Input shape should be tuple of three integers: channels, height, width')

        # `-1` makes each block take its in_channels from the block built
        # just before it (see EncoderBlock).
        self.blocks = nn.Sequential(
            EncoderBlock(input_shape[0], 16, 7),
            EncoderBlock(-1, 16, 5),
            EncoderBlock(-1, 16, 3),
            EncoderBlock(-1, 16, 3),
        )

        # 16*2 input channels: left and right feature maps are concatenated.
        self.latent = EncoderBlock(16*2, 16, 1)

        # Number of 2x poolings applied in total: one per block, plus the
        # one inside the latent block.
        self.nMaxPool2d: int = len(self.blocks)+1
        # `np.int` was removed from NumPy (deprecated in 1.20); use an
        # explicit fixed-width integer dtype instead.
        self.input_shape: np.ndarray = np.array(input_shape, dtype=np.int64)

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        '''Encode a (left, right) pair of image batches into one latent map.'''
        # The same `self.blocks` (shared weights) processes both inputs.
        left, right = self.blocks(x[0]), self.blocks(x[1])
        latent = self.latent(torch.cat((left, right), dim=1))
        return latent

In [5]:
class Decoder(nn.Module):
    '''Fully connected head mapping the flattened latent map to 10 outputs.'''

    def __init__(self, encoder: 'Encoder'):
        '''Takes in encoder object in order to infer sizes.

        Args:
            encoder: object exposing `input_shape` (channels, height, width)
                and `nMaxPool2d` (number of 2x poolings it applies).
        '''
        super().__init__()

        # Every pooling halves height and width (rounding down), so the
        # latent map is (H // 2**n) x (W // 2**n). Integer floor division
        # replaces the former float division + `np.int` cast; `np.int` no
        # longer exists in current NumPy.
        spatial = np.asarray(encoder.input_shape[1:], dtype=np.int64) // (2 ** encoder.nMaxPool2d)
        # 16 = channel count produced by the encoder's latent block.
        # Converted to a plain Python int before being handed to nn.Linear.
        in_features = 16 * int(spatial.prod())

        self.decoder_block = nn.Sequential(
            nn.Linear(in_features=in_features, out_features=100),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=100, out_features=10),
        )

    def forward(self, x: torch.Tensor):
        '''Map a flattened latent batch (batch, in_features) to (batch, 10).'''
        return self.decoder_block(x)

In [6]:
class FishNet(nn.Module):
    '''Encoder + decoder network operating on a (left, right) image pair.'''

    def __init__(self, input_shape: Tuple[int, int, int] = (3, 416, 416)):
        '''Build the encoder and a decoder sized to match it.

        Args:
            input_shape: (channels, height, width) of one input image.
                Defaults to the previously hard-coded (3, 416, 416).
        '''
        super().__init__()

        self.encoder = Encoder(input_shape=input_shape)
        self.decoder = Decoder(self.encoder)

    def forward(self, x: Tuple[torch.Tensor, torch.Tensor]):
        '''Encode the image pair, flatten the latent map and decode it.'''
        x = self.encoder(x)
        # .view works like reshape here: keep the batch axis (.size(0))
        # and flatten everything else into one feature axis.
        x = x.view(x.size(0), -1)
        x = self.decoder(x)
        return x

In [7]:
# Build the network and move its parameters to the selected device.
model = FishNet().to(device)
# Bare last expression: shows the module tree as the cell output.
model

FishNet(
  (encoder): Encoder(
    (blocks): Sequential(
      (0): EncoderBlock(
        (layers): Sequential(
          (0): Conv2d(3, 16, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
          (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (2): ReLU(inplace=True)
          (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): EncoderBlock(
        (layers): Sequential(
          (0): Conv2d(16, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
          (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
          (2): ReLU(inplace=True)
          (3): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (2): EncoderBlock(
        (layers): Sequential(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (1): MaxPool2d(kernel_size=2, stride=2, padding=0, d

In [8]:
# Smoke test: one forward pass in eval mode without gradient tracking,
# printing only the shape of the output.
model.eval()
with torch.no_grad():
    print(model((X_l, X_r)).shape)

torch.Size([2, 10])


In [11]:
# Dummy class-index targets [0, 1], one per sample of the synthetic batch.
# The dtype is set explicitly because nn.CrossEntropyLoss rejects class-index
# targets that are not int64 (torch.long).
a = torch.arange((2), device=device, dtype=torch.int64)

def train_gen(n: int):
    '''Yield the synthetic minibatch ``((X_l, X_r), a)`` exactly `n` times.'''
    for _ in range(n):
        yield (X_l, X_r), a

        
def train(model: nn.Module, n_epochs: int, n_batches: int = 10):
    '''Train `model` with Adam and cross-entropy on batches from `train_gen`.

    Args:
        model: network called as ``model(inputs)`` with the inputs yielded
            by `train_gen`.
        n_epochs: number of passes; each pass draws `n_batches` minibatches.
        n_batches: minibatches per epoch (defaults to the former hard-coded 10).

    The model is left in training mode. Nothing is returned; progress is
    reported through a tqdm bar showing the mean loss of the current epoch.
    '''
    model.train()
    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.CrossEntropyLoss()

    for epoch in range(n_epochs):
        # Iterate over real minibatches instead of a bare counter, so the
        # loop no longer reads the global tensors directly.
        pbar = tqdm(
            iterable=train_gen(n_batches),
            total=n_batches,
            unit=' batches',
            desc=f' Epoch {epoch+1}/{n_epochs}',
            ascii=True,
            position=0,
        )

        running_loss: float = 0.0

        for step, (inputs, targets) in enumerate(pbar, start=1):
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            # Show the mean loss so far; a raw sum grows with the number of
            # batches and is not comparable between runs.
            pbar.set_postfix({'loss': running_loss / step})
            
        
# Fit the model for 50 epochs on the synthetic batch defined above.
train(model, 50)

 Epoch 1/50: 100%|##########| 10/10 [00:00<00:00, 90.30 batches/s, loss=0]
 Epoch 2/50: 100%|##########| 10/10 [00:00<00:00, 112.23 batches/s, loss=0]
 Epoch 3/50: 100%|##########| 10/10 [00:00<00:00, 119.15 batches/s, loss=0]
 Epoch 4/50: 100%|##########| 10/10 [00:00<00:00, 125.18 batches/s, loss=0]
 Epoch 5/50: 100%|##########| 10/10 [00:00<00:00, 125.20 batches/s, loss=0]
 Epoch 6/50: 100%|##########| 10/10 [00:00<00:00, 124.07 batches/s, loss=0]
 Epoch 7/50: 100%|##########| 10/10 [00:00<00:00, 123.74 batches/s, loss=0]
 Epoch 8/50: 100%|##########| 10/10 [00:00<00:00, 124.71 batches/s, loss=0]
 Epoch 9/50: 100%|##########| 10/10 [00:00<00:00, 124.13 batches/s, loss=0]
 Epoch 10/50: 100%|##########| 10/10 [00:00<00:00, 123.89 batches/s, loss=0]
 Epoch 11/50: 100%|##########| 10/10 [00:00<00:00, 124.99 batches/s, loss=0]
 Epoch 12/50: 100%|##########| 10/10 [00:00<00:00, 126.37 batches/s, loss=0]
 Epoch 13/50: 100%|##########| 10/10 [00:00<00:00, 122.30 batches/s, loss=0]
 Epoch 14

In [16]:
# Predict in eval mode and without autograd: `train()` leaves the model in
# training mode, where BatchNorm normalises with the statistics of the
# current batch instead of its running statistics.
model.eval()
with torch.no_grad():
    predicted_classes = model((X_l, X_r)).argmax(dim=1)
predicted_classes

tensor([0, 1], device='cuda:1')

<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>
<br/>

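## Trial stuff

Scratch cells for experimenting with individual components. The guard cell below raises on purpose, so "Run All" stops before reaching them.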

In [10]:
# Presumably a deliberate stop: raising here keeps a top-to-bottom run from
# continuing into the scratch cells that follow -- confirm before removing.
raise NotImplementedError('Guard')

NotImplementedError: Guard

In [None]:
# Exercise the encoder on its own. `Encoder` requires the per-image
# (channels, height, width) shape, taken here from the synthetic input.
enc = Encoder(input_shape=tuple(X_l.shape[1:])).to(device)
print(enc)

enc_output = enc((X_l, X_r))
enc_output.shape

# dec = Decoder(enc).to(device)
# print(dec)

# # .size(0) is respective to batch size
# dec_output = dec(enc_output.view(enc_output.size(0), -1))
# dec_output.shape