In [8]:
import torch
from torch import nn
from torchvision import transforms as T
import utils as ut
import datasets
import os
import math
import collections 

In [4]:
# Pick the accelerator once for the whole notebook: GPU when PyTorch can see one, else CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(dev)

cuda


In [5]:
# Root of the dataset and its three split sub-directories.
data_dir = "data_h3_w3"
train_dir, valid_dir, test_dir = (
    os.path.join(data_dir, split_name)
    for split_name in ("basic_train", "valid", "test")
)

In [11]:
# Immutable bundle of hyper-parameters for one convolutional block.
# ConvBlock reads kernel_size, in_channels, out_channels, stride and
# activation_fun; num_repeat is not read by any code visible here
# (presumably intended for Endurance.build — TODO confirm).
BlockArgs = collections.namedtuple(
    "BlockArgs",
    "kernel_size num_repeat in_channels out_channels stride activation_fun",
)

In [12]:
def conv1x1(in_channels: int, out_channels: int, stride: int = 1) -> nn.Conv2d:
    """Create a pointwise (1x1) 2-D convolution.

    No padding is applied and the bias is left at PyTorch's default (enabled).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        stride: spatial stride of the convolution.

    Returns:
        A freshly initialised ``nn.Conv2d``.
    """
    pointwise = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=1,
        stride=stride,
    )
    return pointwise

def conv(in_channels: int, out_channels: int, kernel_size: int, stride: int = 1,
         padding: "int | None" = None) -> nn.Conv2d:
    """Create a square 2-D convolution, by default with "same" padding.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: side length of the square kernel (odd sizes give exact
            "same" padding; even sizes grow the output by one pixel).
        stride: spatial stride of the convolution.
        padding: explicit zero-padding. When ``None`` (the default) it is
            ``kernel_size // 2``, which equals the previous hard-coded value of
            1 for ``kernel_size == 3``.

    Returns:
        A freshly initialised ``nn.Conv2d`` (bias left at PyTorch's default).
    """
    if padding is None:
        # With p = k // 2 and odd k, the output size is floor((H - 1) / stride) + 1,
        # exactly what a 1x1 conv with the same stride produces. ConvBlock adds its
        # 1x1 shortcut to this path, so a fixed padding of 1 broke every kernel != 3.
        padding = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)



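# NOTE: presumably ConvBlock's earlier signature was
# (self, in_channels, out_channels, stride=1, activation_fun=nn.ReLU);
# these settings (plus kernel_size and num_repeat) are now bundled in BlockArgs.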

class ConvBlock(nn.Module):
    """Residual block with a projected (1x1 conv + BN) shortcut.

    Main path:  conv(kernel_size, stride) -> BN -> act -> conv(kernel_size) -> BN
    Shortcut:   conv1x1(stride) -> BN
    The two are summed and the sum goes through the activation once more.

    Args:
        block_args: a ``BlockArgs``; this block reads ``kernel_size``,
            ``in_channels``, ``out_channels``, ``stride`` and ``activation_fun``
            (a zero-argument callable such as ``nn.ReLU``, instantiated once
            and reused for both activations). ``num_repeat`` is not read here.
    """

    def __init__(self, block_args: BlockArgs):
        super().__init__()
        c_in = block_args.in_channels
        c_out = block_args.out_channels
        k = block_args.kernel_size
        s = block_args.stride

        # Submodules are registered in a fixed order: it determines the order of
        # weight-initialisation RNG draws and of parameters().
        self.cnn1 = conv(c_in, c_out, k, s)
        self.bn1 = nn.BatchNorm2d(c_out)
        self.activation_fun = block_args.activation_fun()
        self.cnn2 = conv(c_out, c_out, k)
        self.bn2 = nn.BatchNorm2d(c_out)

        # Projection shortcut: brings the input to c_out channels with the
        # same stride as the first conv of the main path.
        self.skip = nn.Sequential(conv1x1(c_in, c_out, s), nn.BatchNorm2d(c_out))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = self.skip(x)

        hidden = self.activation_fun(self.bn1(self.cnn1(x)))
        main = self.bn2(self.cnn2(hidden))

        # In-place add, so the shortcut must already match main's shape.
        main += shortcut
        return self.activation_fun(main)

    

class Endurance(nn.Module):
    """Network assembled from a list of ``BlockArgs`` (work in progress).

    So far it only stores its configuration: ``build`` is a stub and no
    ``forward`` is defined, so calling the module hits ``nn.Module``'s default
    forward, which raises ``NotImplementedError``.
    """

    def __init__(self, blocks_args: "list[BlockArgs]"):
        """
        Args:
            blocks_args: one ``BlockArgs`` per stage. Copied into a new list so
                later mutation of the caller's list does not affect the model.
        """
        super().__init__()
        # Previously this read the undefined name `block_args` (NameError on every
        # construction); the parameter is `blocks_args`. The attribute keeps the
        # name `block_args` it was meant to have.
        self.block_args = list(blocks_args)

    def build(self):
        """Placeholder for constructing the layers from ``self.block_args``.

        TODO: not implemented yet; currently does nothing and returns None.
        """
        pass
        